Example #1
0
File: train.py  Project: chewnglow/pyGAT
def train(epoch):
    """Run one training epoch and return the validation loss.

    Relies on module-level globals: ``model``, ``optimizer``, ``features``,
    ``adj``, ``train_pos_adj``, ``train_pos_edge_index``, ``labels``,
    ``idx_train``, ``idx_val``, and ``args``.

    Args:
        epoch: 0-based epoch index (used only for logging).

    Returns:
        float: validation NLL loss for this epoch.
    """
    # NOTE(review): anomaly detection slows autograd considerably; it is a
    # debugging aid — consider disabling it for production runs.
    torch.autograd.set_detect_anomaly(True)
    t = time.time()

    # Forward pass in training mode (dropout active).
    model.train()
    optimizer.zero_grad()
    output = model(features, train_pos_adj)

    # Loss: link prediction (graph reconstruction) vs. node classification.
    if args.link:
        loss_train = recon_loss(output, train_pos_edge_index)
        # TODO: add variational loss
    else:
        loss_train = F.nll_loss(output[idx_train], labels[idx_train])
        acc_train = accuracy(output[idx_train], labels[idx_train])
    loss_train.backward()
    optimizer.step()

    if not args.fastmode:
        # Evaluate validation-set performance separately; eval() deactivates
        # dropout and no_grad() avoids building an autograd graph.
        model.eval()
        with torch.no_grad():
            output = model(features, adj)
    # TODO: autoencoder validation
    # NOTE(review): when args.link is True, `output` is a reconstruction and
    # this NLL loss is likely inappropriate — verify once the autoencoder
    # validation path above is implemented.
    loss_val = F.nll_loss(output[idx_val], labels[idx_val])
    acc_val = accuracy(output[idx_val], labels[idx_val])
    print(
        'Epoch: {:04d}'.format(epoch + 1),
        # .item() replaces the deprecated .data.item() access pattern.
        'loss_train: {:.4f}'.format(loss_train.item()),
        'time: {:.4f}s'.format(time.time() - t))

    return loss_val.item()
Example #2
0
def train(args):
    """Train a CycleGAN-style model pair (NNabla).

    Builds the training and test computation graphs (generators g/f,
    discriminators d_x/d_y), their solvers, data iterators and monitors,
    then runs the training loop with an image-history pool and linear
    learning-rate decay. Parameters are saved once per epoch.

    Args:
        args: parsed command-line namespace; fields used include
            lambda_recon, lambda_idt, learning_rate, init_method, context,
            device_id, type_config, unpool, dataset, batch_size, max_epoch,
            monitor_path, and model_save_path.
    """
    # Settings
    b, c, h, w = 1, 3, 256, 256
    beta1 = 0.5
    beta2 = 0.999
    pool_size = 50
    lambda_recon = args.lambda_recon
    lambda_idt = args.lambda_idt
    base_lr = args.learning_rate
    init_method = args.init_method

    # Context (falls back to CPU when none is specified)
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Inputs: raw images are augmented before entering the graphs; the
    # *_history variables hold pooled fake images for discriminator training.
    x_raw = nn.Variable([b, c, h, w], need_grad=False)
    y_raw = nn.Variable([b, c, h, w], need_grad=False)
    x_real = image_augmentation(x_raw)
    y_real = image_augmentation(y_raw)
    x_history = nn.Variable([b, c, h, w])
    y_history = nn.Variable([b, c, h, w])
    x_real_test = nn.Variable([b, c, h, w], need_grad=False)
    y_real_test = nn.Variable([b, c, h, w], need_grad=False)

    # Models for training
    # Generate (g: X->Y, f: Y->X)
    y_fake = models.g(x_real, unpool=args.unpool, init_method=init_method)
    x_fake = models.f(y_real, unpool=args.unpool, init_method=init_method)
    # persistent=True keeps these buffers alive for monitoring and the pool.
    y_fake.persistent, x_fake.persistent = True, True
    # Reconstruct (cycle: f(g(x)) and g(f(y)))
    x_recon = models.f(y_fake, unpool=args.unpool, init_method=init_method)
    y_recon = models.g(x_fake, unpool=args.unpool, init_method=init_method)
    # Discriminate
    d_y_fake = models.d_y(y_fake, init_method=init_method)
    d_x_fake = models.d_x(x_fake, init_method=init_method)
    d_y_real = models.d_y(y_real, init_method=init_method)
    d_x_real = models.d_x(x_real, init_method=init_method)
    d_y_history = models.d_y(y_history, init_method=init_method)
    d_x_history = models.d_x(x_history, init_method=init_method)

    # Models for test (same generators, separate input variables)
    y_fake_test = models.g(
        x_real_test, unpool=args.unpool, init_method=init_method)
    x_fake_test = models.f(
        y_real_test, unpool=args.unpool, init_method=init_method)
    y_fake_test.persistent, x_fake_test.persistent = True, True
    # Reconstruct
    x_recon_test = models.f(
        y_fake_test, unpool=args.unpool, init_method=init_method)
    y_recon_test = models.g(
        x_fake_test, unpool=args.unpool, init_method=init_method)

    # Losses
    # Reconstruction (cycle-consistency) loss
    loss_recon = models.recon_loss(x_recon, x_real) \
        + models.recon_loss(y_recon, y_real)
    # Generator loss (LSGAN adversarial terms + weighted cycle loss)
    loss_gen = models.lsgan_loss(d_y_fake) \
        + models.lsgan_loss(d_x_fake) \
        + lambda_recon * loss_recon
    # Identity loss (optional; encourages g/f to be identity on their
    # own target domain)
    if lambda_idt != 0:
        logger.info("Identity loss was added.")
        # Identity
        y_idt = models.g(y_real, unpool=args.unpool, init_method=init_method)
        x_idt = models.f(x_real, unpool=args.unpool, init_method=init_method)
        loss_idt = models.recon_loss(x_idt, x_real) \
            + models.recon_loss(y_idt, y_real)
        loss_gen += lambda_recon * lambda_idt * loss_idt
    # Discriminator losses (use pooled fake images as the fake input)
    loss_dis_y = models.lsgan_loss(d_y_history, d_y_real)
    loss_dis_x = models.lsgan_loss(d_x_history, d_x_real)

    # Solvers — one Adam per parameter group, bound via parameter scopes.
    solver_gen = S.Adam(base_lr, beta1, beta2)
    solver_dis_x = S.Adam(base_lr, beta1, beta2)
    solver_dis_y = S.Adam(base_lr, beta1, beta2)
    with nn.parameter_scope('generator'):
        solver_gen.set_parameters(nn.get_parameters())
    with nn.parameter_scope('discriminator'):
        with nn.parameter_scope("x"):
            solver_dis_x.set_parameters(nn.get_parameters())
        with nn.parameter_scope("y"):
            solver_dis_y.set_parameters(nn.get_parameters())

    # Datasets
    rng = np.random.RandomState(313)
    ds_train_B = cycle_gan_data_source(
        args.dataset, train=True, domain="B", shuffle=True, rng=rng)
    ds_train_A = cycle_gan_data_source(
        args.dataset, train=True, domain="A", shuffle=True, rng=rng)
    ds_test_B = cycle_gan_data_source(
        args.dataset, train=False, domain="B", shuffle=False, rng=rng)
    ds_test_A = cycle_gan_data_source(
        args.dataset, train=False, domain="A", shuffle=False, rng=rng)
    di_train_B = cycle_gan_data_iterator(ds_train_B, args.batch_size)
    di_train_A = cycle_gan_data_iterator(ds_train_A, args.batch_size)
    di_test_B = cycle_gan_data_iterator(ds_test_B, args.batch_size)
    di_test_A = cycle_gan_data_iterator(ds_test_A, args.batch_size)

    # Monitors
    monitor = Monitor(args.monitor_path)

    def make_monitor(name):
        # Scalar series monitor logged every iteration it is fed.
        return MonitorSeries(name, monitor, interval=1)
    monitor_loss_gen = make_monitor('generator_loss')
    monitor_loss_dis_x = make_monitor('discriminator_B_domain_loss')
    monitor_loss_dis_y = make_monitor('discriminator_A_domain_loss')

    def make_monitor_image(name):
        # Image monitor; maps [-1, 1] outputs back to [0, 255].
        return MonitorImage(name, monitor, interval=1,
                            normalize_method=lambda x: (x + 1.0) * 127.5)
    monitor_train_gx = make_monitor_image('fake_images_train_A')
    monitor_train_fy = make_monitor_image('fake_images_train_B')
    monitor_train_x_recon = make_monitor_image('fake_images_B_recon_train')
    monitor_train_y_recon = make_monitor_image('fake_images_A_recon_train')
    monitor_test_gx = make_monitor_image('fake_images_test_A')
    monitor_test_fy = make_monitor_image('fake_images_test_B')
    monitor_test_x_recon = make_monitor_image('fake_images_recon_test_B')
    monitor_test_y_recon = make_monitor_image('fake_images_recon_test_A')
    monitor_train_list = [
        (monitor_train_gx, y_fake),
        (monitor_train_fy, x_fake),
        (monitor_train_x_recon, x_recon),
        (monitor_train_y_recon, y_recon),
        (monitor_loss_gen, loss_gen),
        (monitor_loss_dis_x, loss_dis_x),
        (monitor_loss_dis_y, loss_dis_y),
    ]
    monitor_test_list = [
        (monitor_test_gx, y_fake_test),
        (monitor_test_fy, x_fake_test),
        (monitor_test_x_recon, x_recon_test),
        (monitor_test_y_recon, y_recon_test)]

    # ImagePool: history buffers of generated images for the discriminators.
    pool_x = ImagePool(pool_size)
    pool_y = ImagePool(pool_size)

    # Training loop
    epoch = 0
    n_images = np.max([ds_train_B.size, ds_train_A.size]
                      )  # num. images for each domain
    max_iter = args.max_epoch * n_images // args.batch_size
    for i in range(max_iter):
        # End-of-epoch bookkeeping: monitoring, test forward pass,
        # checkpointing, and learning-rate decay.
        if int((i+1) % (n_images // args.batch_size)) == 0:
            logger.info("Mode:Test,Epoch:{}".format(epoch))
            # Monitor for train
            # NOTE(review): the loop variable shadows the outer Monitor
            # instance `monitor`; harmless here, but worth renaming.
            for monitor, v in monitor_train_list:
                monitor.add(i, v.d)
            # Use the training graph since there is no separate test mode.
            x_data, _ = di_test_B.next()
            y_data, _ = di_test_A.next()
            x_real_test.d = x_data
            y_real_test.d = y_data
            x_recon_test.forward()
            y_recon_test.forward()
            # Monitor for test
            for monitor, v in monitor_test_list:
                monitor.add(i, v.d)
            # Save model
            nn.save_parameters(os.path.join(
                args.model_save_path, 'params_%06d.h5' % i))
            # Learning rate decay
            for solver in [solver_gen, solver_dis_x, solver_dis_y]:
                linear_decay(solver, base_lr, epoch, args.max_epoch)
            epoch += 1

        # Get data
        x_data, _ = di_train_B.next()
        y_data, _ = di_train_A.next()
        x_raw.d = x_data
        y_raw.d = y_data

        # Train Generators
        loss_gen.forward(clear_no_need_grad=False)
        solver_gen.zero_grad()
        loss_gen.backward(clear_buffer=True)
        solver_gen.update()

        # Insert the fresh fakes into the pool and fetch (possibly older)
        # samples back for discriminator training.
        x_history.d = pool_x.insert_then_get(x_fake.d)
        y_history.d = pool_y.insert_then_get(y_fake.d)

        # Train Discriminator Y
        loss_dis_y.forward(clear_no_need_grad=False)
        solver_dis_y.zero_grad()
        loss_dis_y.backward(clear_buffer=True)
        solver_dis_y.update()

        # Train Discriminator X
        loss_dis_x.forward(clear_no_need_grad=False)
        solver_dis_x.zero_grad()
        loss_dis_x.backward(clear_buffer=True)
        solver_dis_x.update()
Example #3
0
def train(args):
    """Train a MUNIT-style unpaired image-to-image model (NNabla, multi-GPU).

    Builds content/style encoders, decoders and multi-scale discriminators
    for two domains, wires up the reconstruction and LSGAN losses, and runs
    a data-parallel training loop with gradient all-reduce across devices.
    Monitoring and checkpointing happen only on local rank 0.

    Args:
        args: parsed command-line namespace; fields used include
            type_config, batch_size, image_size, maps, lambda_x, lambda_c,
            lambda_s, lr_g, lr_d, beta1, beta2, weight_decay_rate,
            monitor_path, img_path_a, img_path_b, max_iter,
            lr_decay_at_every, lr_decay_rate, model_save_interval.
    """
    # Create Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    # (Paralell is the spelling used by the NNabla API itself.)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Input
    b, c, h, w = args.batch_size, 3, args.image_size, args.image_size
    x_real_a = nn.Variable([b, c, h, w])
    x_real_b = nn.Variable([b, c, h, w])

    # Model
    # workaround for starting with the same model among devices.
    np.random.seed(412)
    maps = args.maps
    # within-domain reconstruction (domain A)
    x_content_a = content_encoder(x_real_a, maps, name="content-encoder-a")
    x_style_a = style_encoder(x_real_a, maps, name="style-encoder-a")
    x_recon_a = decoder(x_content_a, x_style_a, name="decoder-a")
    # within-domain reconstruction (domain B)
    x_content_b = content_encoder(x_real_b, maps, name="content-encoder-b")
    x_style_b = style_encoder(x_real_b, maps, name="style-encoder-b")
    x_recon_b = decoder(x_content_b, x_style_b, name="decoder-b")
    # generate over domains and reconstruction of content and style (domain A):
    # B content + random A style -> fake A, then re-encode to recover both.
    z_style_a = F.randn(shape=x_style_a.shape)
    x_fake_a = decoder(x_content_b, z_style_a, name="decoder-a")
    x_content_rec_b = content_encoder(x_fake_a, maps, name="content-encoder-a")
    x_style_rec_a = style_encoder(x_fake_a, maps, name="style-encoder-a")
    # generate over domains and reconstruction of content and style (domain B)
    z_style_b = F.randn(shape=x_style_b.shape)
    x_fake_b = decoder(x_content_a, z_style_b, name="decoder-b")
    x_content_rec_a = content_encoder(x_fake_b, maps, name="content-encoder-b")
    x_style_rec_b = style_encoder(x_fake_b, maps, name="style-encoder-b")
    # discriminate (both domains; multi-scale discriminators return lists)
    p_x_fake_a_list = discriminators(x_fake_a)
    p_x_real_a_list = discriminators(x_real_a)
    p_x_fake_b_list = discriminators(x_fake_b)
    p_x_real_b_list = discriminators(x_real_b)

    # Loss
    # within-domain reconstruction
    # (persistent=True keeps each loss buffer alive for monitoring)
    loss_recon_x_a = recon_loss(x_recon_a, x_real_a).apply(persistent=True)
    loss_recon_x_b = recon_loss(x_recon_b, x_real_b).apply(persistent=True)
    # content and style reconstruction
    loss_recon_x_style_a = recon_loss(x_style_rec_a,
                                      z_style_a).apply(persistent=True)
    loss_recon_x_content_b = recon_loss(x_content_rec_b,
                                        x_content_b).apply(persistent=True)
    loss_recon_x_style_b = recon_loss(x_style_rec_b,
                                      z_style_b).apply(persistent=True)
    loss_recon_x_content_a = recon_loss(x_content_rec_a,
                                        x_content_a).apply(persistent=True)

    # adversarial

    def f(x, y):
        # Pairwise sum used with reduce() to total per-scale losses.
        return x + y

    loss_gen_a = reduce(f, [lsgan_loss(p_f)
                            for p_f in p_x_fake_a_list]).apply(persistent=True)
    loss_dis_a = reduce(f, [
        lsgan_loss(p_f, p_r)
        for p_f, p_r in zip(p_x_fake_a_list, p_x_real_a_list)
    ]).apply(persistent=True)
    loss_gen_b = reduce(f, [lsgan_loss(p_f)
                            for p_f in p_x_fake_b_list]).apply(persistent=True)
    loss_dis_b = reduce(f, [
        lsgan_loss(p_f, p_r)
        for p_f, p_r in zip(p_x_fake_b_list, p_x_real_b_list)
    ]).apply(persistent=True)
    # loss for generator-related models
    loss_gen = loss_gen_a + loss_gen_b \
        + args.lambda_x * (loss_recon_x_a + loss_recon_x_b) \
        + args.lambda_c * (loss_recon_x_content_a + loss_recon_x_content_b) \
        + args.lambda_s * (loss_recon_x_style_a + loss_recon_x_style_b)
    # loss for discriminators
    loss_dis = loss_dis_a + loss_dis_b

    # Solver
    lr_g, lr_d, beta1, beta2 = args.lr_g, args.lr_d, args.beta1, args.beta2
    # solver for generator-related models
    solver_gen = S.Adam(lr_g, beta1, beta2)
    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
    solver_gen.set_parameters(params_gen)
    # solver for discriminators
    solver_dis = S.Adam(lr_d, beta1, beta2)
    with nn.parameter_scope("discriminators"):
        params_dis = nn.get_parameters()
    solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    # time
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    # reconstruction
    monitor_loss_recon_x_a = MonitorSeries("Recon Loss Image A",
                                           monitor,
                                           interval=10)
    monitor_loss_recon_x_content_b = MonitorSeries("Recon Loss Content B",
                                                   monitor,
                                                   interval=10)
    monitor_loss_recon_x_style_a = MonitorSeries("Recon Loss Style A",
                                                 monitor,
                                                 interval=10)
    monitor_loss_recon_x_b = MonitorSeries("Recon Loss Image B",
                                           monitor,
                                           interval=10)
    monitor_loss_recon_x_content_a = MonitorSeries("Recon Loss Content A",
                                                   monitor,
                                                   interval=10)
    monitor_loss_recon_x_style_b = MonitorSeries("Recon Loss Style B",
                                                 monitor,
                                                 interval=10)
    # adversarial
    monitor_loss_gen_a = MonitorSeries("Gen Loss A", monitor, interval=10)
    monitor_loss_dis_a = MonitorSeries("Dis Loss A", monitor, interval=10)
    monitor_loss_gen_b = MonitorSeries("Gen Loss B", monitor, interval=10)
    monitor_loss_dis_b = MonitorSeries("Dis Loss B", monitor, interval=10)
    monitor_losses = [
        # reconstruction
        (monitor_loss_recon_x_a, loss_recon_x_a),
        (monitor_loss_recon_x_content_b, loss_recon_x_content_b),
        (monitor_loss_recon_x_style_a, loss_recon_x_style_a),
        (monitor_loss_recon_x_b, loss_recon_x_b),
        (monitor_loss_recon_x_content_a, loss_recon_x_content_a),
        (monitor_loss_recon_x_style_b, loss_recon_x_style_b),
        # adversarial
        (monitor_loss_gen_a, loss_gen_a),
        (monitor_loss_dis_a, loss_dis_a),
        (monitor_loss_gen_b, loss_gen_b),
        (monitor_loss_dis_b, loss_dis_b)
    ]
    # image
    monitor_image_a = MonitorImage("Fake Image B to A Train",
                                   monitor,
                                   interval=1)
    monitor_image_b = MonitorImage("Fake Image A to B Train",
                                   monitor,
                                   interval=1)
    monitor_images = [
        (monitor_image_a, x_fake_a),
        (monitor_image_b, x_fake_b),
    ]

    # DataIterator — distinct RNG seed per device and per domain so each
    # worker draws a different sample stream.
    rng_a = np.random.RandomState(device_id)
    rng_b = np.random.RandomState(device_id + n_devices)
    di_a = munit_data_iterator(args.img_path_a, args.batch_size, rng=rng_a)
    di_b = munit_data_iterator(args.img_path_b, args.batch_size, rng=rng_b)

    # Train
    for i in range(args.max_iter // n_devices):
        # ii is the global iteration count across all devices.
        ii = i * n_devices
        # Train generator-related models
        x_data_a, x_data_b = di_a.next()[0], di_b.next()[0]
        x_real_a.d, x_real_b.d = x_data_a, x_data_b
        solver_gen.zero_grad()
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        comm.all_reduce([w.grad for w in params_gen.values()])
        solver_gen.weight_decay(args.weight_decay_rate)
        solver_gen.update()

        # Train discriminators; need_grad is switched off on the fakes so
        # the discriminator backward pass does not reach the generators.
        x_data_a, x_data_b = di_a.next()[0], di_b.next()[0]
        x_real_a.d, x_real_b.d = x_data_a, x_data_b
        x_fake_a.need_grad, x_fake_b.need_grad = False, False
        solver_dis.zero_grad()
        loss_dis.forward(clear_no_need_grad=True)
        loss_dis.backward(clear_buffer=True)
        comm.all_reduce([w.grad for w in params_dis.values()])
        solver_dis.weight_decay(args.weight_decay_rate)
        solver_dis.update()
        x_fake_a.need_grad, x_fake_b.need_grad = True, True

        # LR schedule: multiplicative decay every lr_decay_at_every
        # global iterations.
        if (i + 1) % (args.lr_decay_at_every // n_devices) == 0:
            lr_d = solver_dis.learning_rate() * args.lr_decay_rate
            lr_g = solver_gen.learning_rate() * args.lr_decay_rate
            solver_dis.set_learning_rate(lr_d)
            solver_gen.set_learning_rate(lr_g)

        if mpi_local_rank == 0:
            # Monitor
            monitor_time.add(ii)
            for mon, loss in monitor_losses:
                mon.add(ii, loss.d)
            # Save
            if (i + 1) % (args.model_save_interval // n_devices) == 0:
                for mon, x in monitor_images:
                    mon.add(ii, x.d)
                nn.save_parameters(
                    os.path.join(args.monitor_path,
                                 "param_{:05d}.h5".format(i)))

    # Final monitoring/checkpoint after the loop.
    # NOTE(review): `i` and `ii` are undefined here if the loop above never
    # ran (max_iter < n_devices) — this would raise NameError; confirm
    # max_iter is always >= n_devices.
    if mpi_local_rank == 0:
        # Monitor
        for mon, loss in monitor_losses:
            mon.add(ii, loss.d)
        # Save
        for mon, x in monitor_images:
            mon.add(ii, x.d)
        nn.save_parameters(
            os.path.join(args.monitor_path, "param_{:05d}.h5".format(i)))