lossL1 = criterionL1(gen_out, targets) lossL1.backward() optimizerG.step() lossL1viz = lossL1.item() L1_accum += lossL1viz if i == len(trainLoader) - 1: logline = "Epoch: {}, batch-idx: {}, L1: {}\n".format( epoch, i, lossL1viz) print(logline) # validation netG.eval() L1val_accum = 0.0 for i, validata in enumerate(valiLoader, 0): inputs_cpu, targets_cpu = validata targets_cpu, inputs_cpu = targets_cpu.float().cuda(), inputs_cpu.float( ).cuda() inputs.data.resize_as_(inputs_cpu).copy_(inputs_cpu) targets.data.resize_as_(targets_cpu).copy_(targets_cpu) outputs = netG(inputs) outputs_cpu = outputs.data.cpu().numpy() lossL1 = criterionL1(outputs, targets) L1val_accum += lossL1.item() if i == 0:
def getModel(expo):
    """Construct a TurbNetG network and restore its saved weights.

    Args:
        expo: channel exponent forwarded as ``TurbNetG(channelExponent=expo)``;
            it also selects the checkpoint file ``models/model_w_{expo}``.

    Returns:
        The network, placed on the module-level ``device`` and switched to
        evaluation mode via ``eval()``.
    """
    checkpoint_path = f'models/model_w_{expo}'
    model = TurbNetG(channelExponent=expo).to(device)
    # map_location remaps tensors saved on another device onto `device`.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints; consider weights_only=True once torch >= 1.13 is guaranteed.
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state)
    model.eval()
    return model