Exemplo n.º 1
0
    # Gradient-accumulation training loop (fragment: the enclosing function and
    # whatever follows the accumulation rounds are not visible here).
    # Each outer iteration zeroes the gradients once, then runs num_accums
    # rounds; every round does one forward/backward pass per key of
    # data_labels_paths, so gradients from all of them are summed.
    num_accums = config.num_traj
    for batch_idx in range(config.train_size //
                           (config.batch_size * config.num_traj)):
        optimizer.zero_grad()
        # Running sum of the scaled losses for this accumulation window.
        loss_sum = Variable(torch.zeros(1)).cuda().data
        for _ in range(num_accums):
            for k in data_labels_paths.keys():
                tick = time.time()

                # Pull the next batch from the generator registered for key k.
                data, labels = next(train_gen_objs[k])

                print('fetch data cost ' + str(time.time() - tick) + 'sec')
                tick = time.time()

                # Keep only the first config.top_k + 1 entries along axis 2
                # (data is indexed as a 6-D array here).
                data = data[:, :, 0:config.top_k + 1, :, :, :]
                one_hot_labels = prepare_input_op(labels,
                                                  len(generator.unique_draw))
                one_hot_labels = Variable(
                    torch.from_numpy(one_hot_labels)).cuda()
                data = Variable(torch.from_numpy(data)).cuda()
                labels = Variable(torch.from_numpy(labels)).cuda()
                # Swap the first two axes before feeding the network.
                data = data.permute(1, 0, 2, 3, 4, 5)

                # forward pass; k is passed to the network and reused below as
                # time_steps=k + 1, so it presumably encodes the program length
                # for this key (TODO confirm).
                outputs = imitate_net([data, one_hot_labels, k])

                # The loss is divided by types_prog and num_accums so the summed
                # gradients form an average; types_prog is presumably the
                # number of keys in data_labels_paths — confirm.
                loss = losses_joint(outputs, labels, time_steps=k + 1) / types_prog / \
                       num_accums
                loss.backward()
                loss_sum += loss.data

                # NOTE(review): CUDA work is asynchronous; without a
                # torch.cuda.synchronize() before reading the clock this timing
                # may under-report the actual GPU time of the batch.
                print('train one batch cost' + str(time.time() - tick) + 'sec')
Exemplo n.º 2
0
def _train_one_epoch(csgnet, optimizer, train_dataset, num_accums):
    """Run one training epoch with gradient accumulation.

    Gradients of ``num_accums`` consecutive full batches are accumulated (each
    loss scaled by ``1 / num_accums``) before a single clipped optimizer step.
    Batches whose size differs from ``config.batch_size`` are skipped. The
    gradients of a trailing window shorter than ``num_accums`` are discarded
    when the next epoch starts.

    Args:
        csgnet: model exposing ``forward2([data, one_hot_labels, max_len])``.
        optimizer: optimizer over the trainable parameters of ``csgnet``.
        train_dataset: iterable of batches; each batch is a sequence of
            ``(labels, data)`` pairs.
        num_accums: number of batches accumulated per optimizer step.

    Returns:
        The mean unscaled ``losses_joint`` value per full batch (a tensor).

    Raises:
        ValueError: if ``train_dataset`` yields no batch of full size.
    """
    csgnet.train()
    # Start from clean gradients; zeroing happens again only after each
    # optimizer step, so gradients really accumulate across the window.
    optimizer.zero_grad()
    total_loss = 0
    num_batches = 0
    for batch in train_dataset:
        labels = np.stack([x[0] for x in batch])
        data = np.stack([x[1] for x in batch])
        if len(labels) != config.batch_size:
            continue

        one_hot_labels = prepare_input_op(labels, len(unique_draws))
        one_hot_labels = Variable(torch.from_numpy(one_hot_labels)).cuda()
        data = Variable(torch.from_numpy(data)).cuda().unsqueeze(-1).float()
        labels = Variable(torch.from_numpy(labels)).cuda()

        # forward pass
        outputs = csgnet.forward2([data, one_hot_labels, max_len])
        loss = losses_joint(outputs, labels, time_steps=max_len + 1)
        # Scale so the accumulated gradient is the mean over the window.
        (loss / num_accums).backward()
        total_loss += loss.detach()
        num_batches += 1

        if num_batches % num_accums == 0:
            # Clip the gradient to fixed value to stabilize training.
            torch.nn.utils.clip_grad_norm_(csgnet.parameters(), 20)
            optimizer.step()
            optimizer.zero_grad()

    if num_batches == 0:
        raise ValueError(
            'train_dataset yielded no batch of size config.batch_size')
    return total_loss / num_batches


def _validate_one_epoch(csgnet, val_dataset):
    """Evaluate ``csgnet`` on every full-size batch of ``val_dataset``.

    Args:
        csgnet: model exposing ``forward2`` and ``test2``.
        val_dataset: iterable of batches of ``(labels, data)`` pairs.

    Returns:
        ``(test_loss, test_reward, acc)``:
        - the mean per-batch ``losses_joint`` of ``forward2``, which is fed the
          ground-truth one-hot labels;
        - the mean per-sample IoU between the canvas produced from ``test2``
          output and the input data;
        - the mean per-batch fraction of ``forward2`` argmax tokens that match
          the labels.

    Raises:
        ValueError: if ``val_dataset`` yields no batch of full size.
    """
    csgnet.eval()
    # The parser is configured only from module-level settings, so it is built
    # once instead of once per batch.
    parser = ParseModelOutput(unique_draws,
                              stack_size=(max_len + 1) // 2 + 1,
                              steps=max_len,
                              canvas_shape=[64, 64, 64],
                              primitives=primitives)
    total_loss = 0
    total_acc = 0
    total_reward = 0
    num_batches = 0
    num_samples = 0
    for batch in val_dataset:
        labels = np.stack([x[0] for x in batch])
        data = np.stack([x[1] for x in batch])
        if len(labels) != config.batch_size:
            continue
        with torch.no_grad():
            one_hot_labels = prepare_input_op(labels, len(unique_draws))
            one_hot_labels = Variable(
                torch.from_numpy(one_hot_labels)).cuda()
            data = Variable(
                torch.from_numpy(data)).cuda().unsqueeze(-1).float()
            labels = Variable(torch.from_numpy(labels)).cuda()

            outputs = csgnet.forward2([data, one_hot_labels, max_len])
            total_loss += losses_joint(outputs, labels,
                                       time_steps=max_len + 1).data
            # Stacked outputs are presumably (steps, batch, vocab); argmax over
            # the vocabulary and transpose to line up with labels.
            predictions = torch.argmax(torch.stack(outputs),
                                       dim=2).permute(1, 0)
            total_acc += float((predictions == labels).float().sum()) \
                / (labels.shape[0] * labels.shape[1])

            programs = csgnet.test2(data, max_len)
            stack, _, _ = parser.get_final_canvas(
                programs,
                if_pred_images=True,
                if_just_expressions=False)
            # NOTE(review): squeeze() also drops the batch axis when
            # batch_size == 1; confirm whether squeeze(-1) is intended.
            data_ = data.squeeze().cpu().numpy()
            # Per-sample IoU; the +1 keeps the denominator non-zero for
            # empty shapes.
            iou = np.sum(np.logical_and(stack, data_), (1, 2, 3)) / (
                np.sum(np.logical_or(stack, data_), (1, 2, 3)) + 1)
            total_reward += np.sum(iou)

        num_batches += 1
        num_samples += len(data)

    if num_batches == 0:
        raise ValueError(
            'val_dataset yielded no batch of size config.batch_size')
    return (total_loss / num_batches,
            total_reward / num_samples,
            total_acc / num_batches)


def train_model(csgnet, train_dataset, val_dataset, max_epochs=None):
    """Train ``csgnet`` in place with early stopping on the validation loss.

    Each epoch runs one pass over ``train_dataset``. Gradients are accumulated
    over ``config.num_traj`` batches per optimizer step. The epoch then
    evaluates on ``val_dataset`` and prints the train loss, test loss, test
    IoU and test accuracy.

    If the validation loss fails to improve for 3 consecutive epochs, the
    weights with the lowest validation loss seen so far are restored and
    training stops. When ``config.if_schedule`` is set, the learning-rate
    scheduler is stepped on the negated validation IoU after each epoch.

    Args:
        csgnet: model to train; modified in place.
        train_dataset: iterable of batches of ``(labels, data)`` pairs,
            re-iterated every epoch.
        val_dataset: iterable of validation batches in the same format.
        max_epochs: maximum number of epochs; ``None`` means 100.
    """
    import copy

    epochs = 100 if max_epochs is None else max_epochs

    optimizer = optim.Adam(
        [para for para in csgnet.parameters() if para.requires_grad],
        weight_decay=config.weight_decay,
        lr=config.lr)

    reduce_plat = LearningRate(optimizer,
                               init_lr=config.lr,
                               lr_dacay_fact=0.2,
                               lr_decay_epoch=3,
                               patience=config.patience)

    # Number of batches whose gradients are accumulated per optimizer step.
    num_accums = config.num_traj
    best_state_dict = None
    patience = 3
    prev_test_loss = 1e20
    prev_test_reward = 0
    num_worse = 0
    for epoch in range(epochs):
        mean_train_loss = _train_one_epoch(csgnet, optimizer, train_dataset,
                                           num_accums)
        print(f'train loss epoch {epoch}: {float(mean_train_loss)}')

        test_loss, test_reward, acc = _validate_one_epoch(csgnet, val_dataset)
        print(f'test loss epoch {epoch}: {float(test_loss)}')
        print(f'test IOU epoch {epoch}: {test_reward}')
        print(f'test acc epoch {epoch}: {acc}')

        if test_loss < prev_test_loss:
            prev_test_loss = test_loss
            # state_dict() returns references to the live tensors; without a
            # deep copy the "best" snapshot would follow every later update.
            best_state_dict = copy.deepcopy(csgnet.state_dict())
            num_worse = 0
        else:
            num_worse += 1
        if num_worse >= patience:
            if best_state_dict is not None:
                csgnet.load_state_dict(best_state_dict)
            break

        if config.if_schedule:
            reduce_plat.reduce_on_plateu(-test_reward)

        if test_reward > prev_test_reward:
            prev_test_reward = test_reward