def main(FG):
    """Train the CAE with an MSE reconstruction loss and plot progress in Visdom.

    Args:
        FG: Run configuration object. This function reads ``vis_env``,
            ``lr``, ``lr_gamma`` and ``num_epoch`` from it, passes it to
            ``argument_report`` and ``make_dataloader``, and resets
            ``global_step`` to 0 before incrementing it once per
            optimizer update.
    """
    vis = Visdom(port=10001, env=str(FG.vis_env))
    vis.text(argument_report(FG, end='<br>'), win='config')

    FG.global_step = 0

    cae = CAE().cuda()
    print_model_parameters(cae)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(cae.parameters(), lr=FG.lr, betas=(0.5, 0.999))
    schedular = torch.optim.lr_scheduler.ExponentialLR(optimizer, FG.lr_gamma)

    # Keys must match the lookups in the training loop ('loss', 'input',
    # 'output'). The Visdom axis option is spelled ``ytickmax``.
    printers = {
        'loss': summary.Scalar(vis, 'loss', opts=dict(
            showlegend=True, title='loss', ytickmin=0, ytickmax=2.0)),
        'lr': summary.Scalar(vis, 'lr', opts=dict(
            showlegend=True, title='lr', ytickmin=0, ytickmax=2.0)),
        'input': summary.Image3D(vis, 'input'),
        'output': summary.Image3D(vis, 'output'),
    }

    # The second loader returned by make_dataloader is not used here.
    trainloader, _ = make_dataloader(FG)
    steps_per_epoch = len(trainloader)

    for epoch in range(FG.num_epoch):
        torch.set_grad_enabled(True)
        with tqdm(total=steps_per_epoch,
                  desc='Epoch {:>3}'.format(epoch)) as pbar:
            for data in trainloader:
                real = data[0][0].cuda()

                # Clear gradients left over from the previous batch;
                # otherwise backward() keeps adding onto them.
                optimizer.zero_grad()
                output = cae(real)
                loss = criterion(output, real)
                loss.backward()
                optimizer.step()

                FG.global_step += 1
                # x-axis is expressed in (fractional) epochs.
                printers['loss']('loss', FG.global_step / steps_per_epoch, loss)
                printers['input']('input', real)
                # Scale the reconstruction by its maximum before display.
                printers['output']('output', output / output.max())
                pbar.update()

        # Decay the learning rate only after this epoch's optimizer steps,
        # so the first epoch runs at the configured FG.lr.
        schedular.step()
# Script section: configure the image loader and the auto-encoder, then walk
# over the file list in fixed-size batches for a number of epochs.
# inp_path, out_path, save_path and on_home are expected to be defined
# before this point.
depth_csv = False
file_end_inp = '.jpg'       # suffix handed to imageGetter for input files
file_end_outp = '_mask.gif'  # suffix handed to getImageSubset for target files
input_channels = 1
use_tresh_hold = False
# save_path =None
use_sync_data = False
# The image side length depends on which data set is selected.
if use_sync_data:
    x_sz = y_sz = 64
else:
    x_sz = y_sz = 128
imget = imageGetter(inp_path, depth_csv, out_path, x_sz , y_sz,file_end_inp)
autoEnc = CAE( x_sz , y_sz, save_path, on_home,input_channels)
batch_size = 50
range_start = 0
epochs = 100
#
# for i in range(100):
#     images_inp, images_outp = imget.getImageSubset(999+i, 1000+i)
#     autoEnc.test(images_inp[0], images_outp[0])
iteration =0
for i in range(epochs):
    print('running epoch ',i)
    # range_end steps through the file list in multiples of batch_size.
    for range_end in range(batch_size,len(imget.filelist), batch_size):
        # TODO get random image set
        # NOTE(review): range_start is not advanced and range_end is not
        # used in the statements shown here, so each iteration would request
        # the same subset unless range_start is updated further down - confirm.
        images_inp, images_outp = imget.getImageSubset(range_start, batch_size, use_sync_data,file_end_outp)
        # images_inp, images_outp = imget.create_test_data(range_start, range_end)