def main(FG):
    """Train the convolutional autoencoder (CAE) described by config ``FG``.

    Sets up Visdom logging, builds the model/optimizer/scheduler, then runs
    ``FG.num_epoch`` epochs of MSE reconstruction training, pushing the loss
    and input/output volumes to Visdom every step.

    Args:
        FG: argparse-style config namespace. Reads ``vis_env``, ``lr``,
            ``lr_gamma``, ``batch_size``, ``num_epoch``; writes
            ``global_step``. (Exact attribute set assumed from usage here —
            confirm against the argument parser.)
    """
    vis = Visdom(port=10001, env=str(FG.vis_env))
    vis.text(argument_report(FG, end='<br>'), win='config')
    FG.global_step = 0

    cae = CAE().cuda()
    print_model_parameters(cae)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(cae.parameters(), lr=FG.lr, betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, FG.lr_gamma)

    # Visdom printers, keyed exactly as the training loop looks them up.
    # Fixed: the original dict(...) was missing commas between entries, and
    # used keys 'input_printer'/'output_printer' that never matched the
    # printers['input'] / printers['output'] calls below. Also fixed the
    # 'ytinkmax' typo so the y-axis cap actually takes effect.
    printers = dict(
        loss=summary.Scalar(vis, 'loss', opts=dict(
            showlegend=True, title='loss', ytickmin=0, ytickmax=2.0)),
        lr=summary.Scalar(vis, 'lr', opts=dict(
            showlegend=True, title='lr', ytickmin=0, ytickmax=2.0)),
        input=summary.Image3D(vis, 'input'),
        output=summary.Image3D(vis, 'output'))

    # validloader is currently unused; kept for interface parity with
    # make_dataloader and future validation passes.
    trainloader, validloader = make_dataloader(FG)

    for epoch in range(FG.num_epoch):
        torch.set_grad_enabled(True)
        cae.train()
        pbar = tqdm(total=len(trainloader), desc='Epoch {:>3}'.format(epoch))
        for i, data in enumerate(trainloader):
            # data[0][0]: image volume batch — presumably (N, 1, D, H, W);
            # TODO confirm against the dataset's __getitem__.
            real = data[0][0].cuda()

            # Fixed: gradients were never cleared, so they accumulated
            # across every step of training.
            optimizer.zero_grad()
            output = cae(real)
            loss = criterion(output, real)
            loss.backward()
            optimizer.step()

            FG.global_step += 1
            printers['loss']('loss', FG.global_step / len(trainloader), loss)
            printers['input']('input', real)
            # Normalize to [0, 1] for display (assumes non-negative outputs).
            printers['output']('output', output / output.max())
            pbar.update()
        pbar.close()
        # Fixed ordering: step the LR scheduler after the epoch's optimizer
        # steps (PyTorch >= 1.1 convention), not before any training.
        scheduler.step()