"""Optionally train, then evaluate local vs. non-local MARL policies over a range of N.

For each N = args.minN + i * args.divN (i = 0 .. args.numN - 1) the script calls
``evaluateMARLLocal`` and ``evaluateMARLNonLocal`` ``args.maxSeed`` times. It then
records, per N, the mean and population standard deviation of each value and of
their absolute difference. Two band plots are written to ``Results/Values.png``
and ``Results/Error.png``.
"""
import os
import time

import matplotlib.pyplot as plt
import numpy as np

from Scripts.Algorithm import train, evaluateMARLNonLocal, evaluateMARLLocal
from Scripts.Parameters import ParseInput

RESULTS_DIR = 'Results'


def _to_scalar(value):
    """Detach a tensor-like evaluation result and return it as a Python scalar.

    The original script treated each evaluation as a single value, so a
    one-element result is assumed. ``.cpu()`` is applied when available because
    ``np.array`` cannot convert tensors that live on an accelerator.
    """
    value = value.detach()
    if hasattr(value, 'cpu'):
        value = value.cpu()
    return np.asarray(value).item()


def _mean_sd(samples):
    """Return the mean and population standard deviation of ``samples``.

    Returns ``(0.0, 0.0)`` for an empty sample list, matching the zero-initialised
    accumulators of the original script when ``args.maxSeed`` is 0.
    """
    if not samples:
        return 0.0, 0.0
    arr = np.asarray(samples, dtype=float)
    # np.std defaults to ddof=0 (population SD), the same quantity as
    # sqrt(E[x^2] - E[x]^2), but without catastrophic cancellation.
    return arr.mean(), arr.std()


def _plot_band(x, mean, sd, label=None):
    """Plot ``mean`` against ``x`` with a shaded +/- one standard deviation band."""
    kwargs = {} if label is None else {'label': label}
    plt.plot(x, mean, **kwargs)
    plt.fill_between(x, mean - sd, mean + sd, alpha=0.3)


def main():
    """Run optional training, the evaluation sweep over N, and write the plots."""
    args = ParseInput()
    t0 = time.perf_counter()

    num_n = args.numN
    NVec = np.zeros(num_n)
    valueLocalArray = np.zeros(num_n)
    valueLocalArraySD = np.zeros(num_n)
    valueNonLocalArray = np.zeros(num_n)
    valueNonLocalArraySD = np.zeros(num_n)
    ErrorArray = np.zeros(num_n)
    ErrorArraySD = np.zeros(num_n)

    if args.train:
        print('Training is in progress.')
        train(args)

    print('Evaluation is in progress.')
    for indexN in range(num_n):
        N = args.minN + indexN * args.divN
        NVec[indexN] = N

        local_samples, nonlocal_samples, error_samples = [], [], []
        for _ in range(args.maxSeed):
            # Keep the original call order (local first) in case the
            # evaluators share random state.
            valueLocal = _to_scalar(evaluateMARLLocal(args, N))
            valueNonLocal = _to_scalar(evaluateMARLNonLocal(args, N))
            local_samples.append(valueLocal)
            nonlocal_samples.append(valueNonLocal)
            error_samples.append(abs(valueNonLocal - valueLocal))

        valueLocalArray[indexN], valueLocalArraySD[indexN] = _mean_sd(local_samples)
        valueNonLocalArray[indexN], valueNonLocalArraySD[indexN] = _mean_sd(nonlocal_samples)
        ErrorArray[indexN], ErrorArraySD[indexN] = _mean_sd(error_samples)
        print(f'N: {N}')

    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(RESULTS_DIR, exist_ok=True)

    plt.figure()
    plt.xlabel('N')
    plt.ylabel('Values')
    _plot_band(NVec, valueLocalArray, valueLocalArraySD, label='Local')
    _plot_band(NVec, valueNonLocalArray, valueNonLocalArraySD, label='Non-Local')
    plt.legend()
    plt.savefig(os.path.join(RESULTS_DIR, 'Values.png'))
    plt.close()  # release the figure; pyplot keeps open figures alive otherwise

    plt.figure()
    plt.xlabel('N')
    plt.ylabel('Error')
    _plot_band(NVec, ErrorArray, ErrorArraySD)
    plt.savefig(os.path.join(RESULTS_DIR, 'Error.png'))
    plt.close()

    t1 = time.perf_counter()
    print(f'Elapsed time is {t1-t0} sec')


if __name__ == '__main__':
    main()