Example #1
0
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train one generator/discriminator pair per scale, from scale 0 upward.

    Reads the input image, rescales it by ``opt.scale1``, builds the image
    pyramid with ``functions.creat_reals_pyramid``, then trains scales
    ``0 .. opt.stop_scale`` in order via ``train_single_scale``. After each
    scale the frozen generator, the ``z_curr`` value returned by
    ``train_single_scale`` and the current ``opt.noise_amp`` are appended to
    ``Gs``, ``Zs`` and ``NoiseAmp`` (mutated in place). Those three lists and
    the pyramid are then saved under ``opt.out_``.

    Args:
        opt: options object. Read: ``nfc_init``, ``min_nfc_init``, ``scale1``,
            ``stop_scale``, ``noise_amp``. Written: ``nfc``, ``min_nfc``,
            ``out_``, ``outf``.
        Gs: list that receives the trained generators (frozen, eval mode).
        Zs: list that receives each scale's ``z_curr``.
        reals: list handed to ``functions.creat_reals_pyramid``. The pyramid
            it returns is what is trained on and saved. NOTE(review): the
            caller's list only reflects the pyramid if that helper mutates
            it in place; confirm.
        NoiseAmp: list that receives ``opt.noise_amp`` after each scale.

    Returns:
        None. Results are delivered through the mutated lists and the files
        written to ``opt.out_``.
    """
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0

    # ``opt.stop_scale`` is re-evaluated on every pass (as in a while loop),
    # so the stopping point is read afresh before each scale.
    while scale_num < opt.stop_scale + 1:
        # Channel widths double every 4 scales and are capped at 128.
        # Integer floor division avoids the float round-trip of
        # math.floor(scale_num / 4).
        width_factor = 2 ** (scale_num // 4)
        opt.nfc = min(opt.nfc_init * width_factor, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_factor, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        # Tolerate only "directory already exists". Other OS errors
        # (e.g. permission denied) used to be swallowed and then fail later
        # at imsave/torch.save; now they surface here.
        os.makedirs(opt.outf, exist_ok=True)

        plt.imsave('%s/real_scale.png' % (opt.outf),
                   functions.convert_image_np(reals[scale_num]),
                   vmin=0,
                   vmax=1)

        D_curr, G_curr = init_models(opt)
        if nfc_prev == opt.nfc:
            # Same channel width as the previous scale, so warm-start from
            # that scale's checkpoints. NOTE(review): these files are not
            # written in this function; presumably train_single_scale saves
            # them under opt.outf. Confirm.
            prev_dir = '%s/%d' % (opt.out_, scale_num - 1)
            G_curr.load_state_dict(torch.load('%s/netG.pth' % prev_dir))
            D_curr.load_state_dict(torch.load('%s/netD.pth' % prev_dir))

        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs,
                                                  Zs, in_s, NoiseAmp, opt)

        # Freeze the trained pair (reset_grads(..., False)) and switch both
        # modules to eval mode.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()

        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)

        # Re-save the cumulative state after every scale, so an interrupted
        # run still has all completed scales on disk.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        scale_num += 1
        nfc_prev = opt.nfc
        # Drop the local references. The generator stays alive through Gs.
        del D_curr, G_curr
Example #2
0
                accuracy_valid.append(valid_model(dev_sents1_input, dev_sents2_input, dev_y_input))

            #dirty code to correctly asses validation accuracy, last results in the array are predictions for the padding rows and can be dumped afterwards
            this_validation_accuracy = numpy.concatenate(accuracy_valid)[0:n_dev_samples].sum()/float(n_dev_samples)

            if this_validation_accuracy > best_validation_accuracy:
                print("Train loss, "+str( (train_loss/hyperparas["valid_freq"]))+", validation accuracy: "+str(this_validation_accuracy*100)+"%")
                best_validation_accuracy = this_validation_accuracy

                # test it
                accuracy_test= []
                for minibatch_test_index in range(n_test_batches):
                    sents1_input = test_sents1_extended[minibatch_test_index*batch_size:(minibatch_test_index+1)*batch_size,0:test_lens1[(minibatch_test_index+1)*batch_size-1]]
                    sents2_input = test_sents2_extended[minibatch_test_index*batch_size:(minibatch_test_index+1)*batch_size,0:test_lens2[(minibatch_test_index+1)*batch_size-1]]
                    y_input = test_y_extended[minibatch_test_index*batch_size:(minibatch_test_index+1)*batch_size]
                    accuracy_test.append(test_model(sents1_input, sents2_input, y_input))
                this_test_accuracy = numpy.concatenate(accuracy_test)[0:n_test_samples].sum()/float(n_test_samples)
                print("Test accuracy: "+str(this_test_accuracy*100)+"%")

            train_loss=0
        batch_counter+=1

    if hyperparas["adagrad_reset"] > 0:
        if epoch % hyperparas["adagrad_reset"] == 0:
            utils.reset_grads(accumulated_grads)

    print("Epoch "+str(epoch)+" finished.")



                y_input = dev_y_extended[minibatch_dev_index*batch_size:(minibatch_dev_index+1)*batch_size]
                accuracy_valid.append(valid_model(x_input,y_input))

            #dirty code to correctly asses validation accuracy, last results in the array are predictions for the padding rows and can be dumped afterwards
            this_validation_accuracy = numpy.concatenate(accuracy_valid)[0:n_dev_samples].sum()/float(n_dev_samples)

            if this_validation_accuracy > best_validation_accuracy:
                print("Train loss, "+str( (train_accuracy/hyperparas["valid_freq"]))+", validation accuracy: "+str(this_validation_accuracy*100)+"%")
                best_validation_accuracy = this_validation_accuracy

                # test it
                accuracy_test= []
                for minibatch_test_index in range(n_test_batches):
                    x_input = test_x_indexes_extended[minibatch_test_index*batch_size:(minibatch_test_index+1)*batch_size,0:test_lengths[(minibatch_test_index+1)*batch_size-1]]
                    y_input = test_y_extended[minibatch_test_index*batch_size:(minibatch_test_index+1)*batch_size]
                    accuracy_test.append(test_model(x_input,y_input))
                this_test_accuracy = numpy.concatenate(accuracy_test)[0:n_test_samples].sum()/float(n_test_samples)
                print("Test accuracy: "+str(this_test_accuracy*100)+"%")

            train_accuracy=0
        batch_counter+=1

    if hyperparas["adagrad_reset"] > 0:
        if epoch % hyperparas["adagrad_reset"] == 0:
            utils.reset_grads(accumulated_grads)

    print("Epoch "+str(epoch)+" finished.")