def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train one discriminator/generator pair per scale and persist the results.

    Scales ``0 .. opt.stop_scale`` (inclusive) are processed in order. For each
    scale the function sizes the networks, creates the per-scale output
    directory, saves the real image of that scale, trains via
    ``train_single_scale``, freezes both networks and rewrites the checkpoint
    files under ``opt.out_``.

    Args:
        opt: Options object. Read here: ``scale1``, ``stop_scale``,
            ``nfc_init``, ``min_nfc_init``, ``noise_amp``. Written here on
            every scale: ``nfc``, ``min_nfc``, ``out_``, ``outf``.
        Gs: List that receives the frozen generator of each scale (mutated).
        Zs: List that receives the ``z_curr`` value returned by
            ``train_single_scale`` for each scale (mutated).
        reals: Passed to ``functions.creat_reals_pyramid``; the value that call
            returns is what gets indexed per scale and saved afterwards.
        NoiseAmp: List that receives ``opt.noise_amp`` after each scale
            (mutated).

    Returns:
        None. Results are delivered through the mutated lists and the files
        written under ``opt.out_``.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0

    while scale_num < opt.stop_scale + 1:
        # nfc / min_nfc double every four scales and are capped at 128.
        width_factor = pow(2, math.floor(scale_num / 4))
        opt.nfc = min(opt.nfc_init * width_factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_factor, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        # Tolerate an already existing directory, but let real failures
        # (permissions, a file in the way, ...) surface here instead of
        # silently passing and failing later at the first write.
        os.makedirs(opt.outf, exist_ok=True)

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0, vmax=1)

        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            # Same nfc as the previous scale: start from the previous scale's
            # saved weights. NOTE(review): netG.pth / netD.pth are presumably
            # written by train_single_scale into opt.outf -- confirm.
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))

        z_curr, in_s, G_curr = train_single_scale(
            D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)

        # Freeze both networks before the generator is stored for later scales.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        # NOTE(review): opt.noise_amp is not assigned in this function;
        # presumably train_single_scale sets it -- confirm.
        NoiseAmp.append(opt.noise_amp)

        # Checkpoints are rewritten after every scale.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        # Drop the local references before the next scale allocates new models.
        del D_curr, G_curr
# NOTE(review): whitespace-collapsed fragment of a training/evaluation loop.
# The original line breaks and indentation are missing, so everything after
# the first '#' on the line below is currently inert comment text rather than
# executable code. Statements visible in the fragment, in order:
#   1. append valid_model(dev_sents1_input, dev_sents2_input, dev_y_input)
#      to accuracy_valid;
#   2. this_validation_accuracy = sum of the first n_dev_samples entries of the
#      concatenated results, divided by n_dev_samples;
#   3. if it exceeds best_validation_accuracy: print train_loss divided by
#      hyperparas["valid_freq"] together with the validation accuracy;
#   4. store the new best, then score each of the n_test_batches test
#      minibatches with test_model and print the test accuracy computed the
#      same way over n_test_samples;
#   5. train_loss=0, batch_counter+=1;
#   6. if hyperparas["adagrad_reset"] > 0 and epoch is a multiple of it, call
#      utils.reset_grads(accumulated_grads);
#   7. print the "Epoch ... finished." message.
# How steps 4, 5 and 7 nest inside the surrounding loops/branches cannot be
# determined from this line alone -- restore the layout from the upstream file.
accuracy_valid.append(valid_model(dev_sents1_input, dev_sents2_input, dev_y_input)) #dirty code to correctly assess validation accuracy, last results in the array are predictions for the padding rows and can be dumped afterwards this_validation_accuracy = numpy.concatenate(accuracy_valid)[0:n_dev_samples].sum()/float(n_dev_samples) if this_validation_accuracy > best_validation_accuracy: print("Train loss, "+str( (train_loss/hyperparas["valid_freq"]))+", validation accuracy: "+str(this_validation_accuracy*100)+"%") best_validation_accuracy = this_validation_accuracy # test it accuracy_test= [] for minibatch_test_index in range(n_test_batches): sents1_input = test_sents1_extended[minibatch_test_index*batch_size:(minibatch_test_index+1)*batch_size,0:test_lens1[(minibatch_test_index+1)*batch_size-1]] sents2_input = test_sents2_extended[minibatch_test_index*batch_size:(minibatch_test_index+1)*batch_size,0:test_lens2[(minibatch_test_index+1)*batch_size-1]] y_input = test_y_extended[minibatch_test_index*batch_size:(minibatch_test_index+1)*batch_size] accuracy_test.append(test_model(sents1_input, sents2_input, y_input)) this_test_accuracy = numpy.concatenate(accuracy_test)[0:n_test_samples].sum()/float(n_test_samples) print("Test accuracy: "+str(this_test_accuracy*100)+"%") train_loss=0 batch_counter+=1 if hyperparas["adagrad_reset"] > 0: if epoch % hyperparas["adagrad_reset"] == 0: utils.reset_grads(accumulated_grads) print("Epoch "+str(epoch)+" finished.")
# NOTE(review): whitespace-collapsed fragment of a training/evaluation loop.
# The original line breaks and indentation are missing: the two statements
# before the first '#' share one line without a separator (not valid Python as
# written), and everything after that '#' is currently inert comment text.
# Statements visible in the fragment, in order:
#   1. y_input = the dev_y_extended slice for minibatch minibatch_dev_index;
#   2. append valid_model(x_input, y_input) to accuracy_valid;
#   3. this_validation_accuracy = sum of the first n_dev_samples entries of the
#      concatenated results, divided by n_dev_samples;
#   4. if it exceeds best_validation_accuracy: print train_accuracy divided by
#      hyperparas["valid_freq"] together with the validation accuracy;
#   5. store the new best, then score each of the n_test_batches test
#      minibatches with test_model and print the test accuracy computed the
#      same way over n_test_samples;
#   6. train_accuracy=0, batch_counter+=1;
#   7. if hyperparas["adagrad_reset"] > 0 and epoch is a multiple of it, call
#      utils.reset_grads(accumulated_grads);
#   8. print the "Epoch ... finished." message.
# NOTE(review): the printed label reads "Train loss" while the value shown is
# train_accuracy -- confirm which of the two is intended.
# How steps 5, 6 and 8 nest inside the surrounding loops/branches cannot be
# determined from this line alone -- restore the layout from the upstream file.
y_input = dev_y_extended[minibatch_dev_index*batch_size:(minibatch_dev_index+1)*batch_size] accuracy_valid.append(valid_model(x_input,y_input)) #dirty code to correctly assess validation accuracy, last results in the array are predictions for the padding rows and can be dumped afterwards this_validation_accuracy = numpy.concatenate(accuracy_valid)[0:n_dev_samples].sum()/float(n_dev_samples) if this_validation_accuracy > best_validation_accuracy: print("Train loss, "+str( (train_accuracy/hyperparas["valid_freq"]))+", validation accuracy: "+str(this_validation_accuracy*100)+"%") best_validation_accuracy = this_validation_accuracy # test it accuracy_test= [] for minibatch_test_index in range(n_test_batches): x_input = test_x_indexes_extended[minibatch_test_index*batch_size:(minibatch_test_index+1)*batch_size,0:test_lengths[(minibatch_test_index+1)*batch_size-1]] y_input = test_y_extended[minibatch_test_index*batch_size:(minibatch_test_index+1)*batch_size] accuracy_test.append(test_model(x_input,y_input)) this_test_accuracy = numpy.concatenate(accuracy_test)[0:n_test_samples].sum()/float(n_test_samples) print("Test accuracy: "+str(this_test_accuracy*100)+"%") train_accuracy=0 batch_counter+=1 if hyperparas["adagrad_reset"] > 0: if epoch % hyperparas["adagrad_reset"] == 0: utils.reset_grads(accumulated_grads) print("Epoch "+str(epoch)+" finished.")