Ejemplo n.º 1
0
def train_inference(imitate_net,
                    path,
                    max_epochs=None,
                    self_training=False,
                    ab=None):
    """Fine-tune ``imitate_net`` on wake-sleep data with early stopping.

    The encoder is frozen: its parameters get ``requires_grad = False`` and
    stay frozen after this call. The remaining trainable parameters are
    optimized with Adam on batches from ``WakeSleepGen``. After each epoch
    the model is evaluated on the CAD validation set (``data/cad/cad.h5``)
    by the mean Chamfer distance between target and predicted canvases.
    The best weights are checkpointed to ``{path}/best_dict.pth``.

    Args:
        imitate_net: network to train. It needs ``encoder``, ``test`` and a
            forward taking ``[data, one_hot_labels, max_len]``.
        path: directory with the wake-sleep training data. It also receives
            the ``best_dict.pth`` checkpoint.
        max_epochs: maximum number of epochs; defaults to 1000.
        self_training: NOTE(review): not used. The generator is always built
            with ``self_training=True``; confirm this is intended.
        ab: optional multiplier for ``inference_train_size``.

    Returns:
        ``epoch + 1`` if early stopping triggered (after reloading the best
        checkpoint), otherwise ``epochs``.
    """
    epochs = 1000 if max_epochs is None else max_epochs

    config = read_config.Config("config_synthetic.yml")

    if ab is None:
        train_size = inference_train_size
    else:
        train_size = inference_train_size * ab

    generator = WakeSleepGen(f"{path}/",
                             batch_size=config.batch_size,
                             train_size=train_size,
                             canvas_shape=config.canvas_shape,
                             max_len=max_len,
                             self_training=True)
    train_gen = generator.get_train_data()

    cad_generator = Generator()
    val_gen = cad_generator.val_gen(batch_size=config.batch_size,
                                    path="data/cad/cad.h5",
                                    if_augment=False)

    # Freeze the encoder so only the remaining parameters are optimized.
    for parameter in imitate_net.encoder.parameters():
        parameter.requires_grad = False

    optimizer = optim.Adam(
        [para for para in imitate_net.parameters() if para.requires_grad],
        weight_decay=config.weight_decay,
        lr=config.lr)

    # Plateau LR scheduling is currently disabled (reduce_on_plateu is never
    # called), but the scheduler object is still constructed as before.
    reduce_plat = LearningRate(optimizer,
                               init_lr=config.lr,
                               lr_dacay_fact=0.2,
                               patience=config.patience)

    # Write an initial checkpoint so the early-stopping reload always finds
    # a file, even if the validation metric never improves.
    torch.save(imitate_net.state_dict(), f"{path}/best_dict.pth")
    best_test_cd = 1e20

    patience = 20
    num_worse = 0

    steps_per_epoch = train_size // (config.batch_size * config.num_traj)
    num_test_batches = inference_test_size // config.batch_size
    # The parser's arguments do not change between batches, so build it once.
    parser = ParseModelOutput(generator.unique_draw, max_len // 2 + 1,
                              max_len, config.canvas_shape)

    for epoch in range(epochs):
        start = time.time()
        train_loss = 0
        imitate_net.train()
        for batch_idx in range(steps_per_epoch):
            optimizer.zero_grad()
            loss = 0
            # Accumulate gradients over num_traj mini-batches, then step once.
            for _ in range(config.num_traj):
                data, labels = next(train_gen)
                one_hot_labels = prepare_input_op(labels,
                                                  len(generator.unique_draw))
                one_hot_labels = torch.from_numpy(one_hot_labels).to(device)
                data = data.to(device)
                labels = labels.to(device)
                outputs = imitate_net([data, one_hot_labels, max_len])
                loss_k = (
                    (losses_joint(outputs, labels, time_steps=max_len + 1) /
                     (max_len + 1)) / config.num_traj)
                loss_k.backward()
                loss += float(loss_k)
                del loss_k

            optimizer.step()
            train_loss += loss
            print(f"batch {batch_idx} train loss: {loss}")

        # Each step's loss is already averaged over num_traj mini-batches, so
        # average over optimizer steps. Dividing by train_size // batch_size
        # would understate the loss by a factor of num_traj.
        mean_train_loss = train_loss / max(steps_per_epoch, 1)
        print(f"epoch {epoch} mean train loss: {mean_train_loss}")

        imitate_net.eval()
        metrics = {"cos": 0, "iou": 0, "cd": 0}
        CD = 0
        for _ in range(num_test_batches):
            with torch.no_grad():
                # No ground-truth programs are used for CAD validation. The
                # decoder input is built from an all-zero label array.
                labels = np.zeros((config.batch_size, max_len), dtype=np.int32)
                data_ = next(val_gen)
                one_hot_labels = prepare_input_op(labels,
                                                  len(generator.unique_draw))
                one_hot_labels = torch.from_numpy(one_hot_labels).to(device)
                data = torch.from_numpy(data_).to(device)
                # Presumably data_ is (steps, batch, channels, H, W) and the
                # last step / first channel is the target image. TODO confirm.
                test_outputs = imitate_net.test(
                    [data[-1, :, 0, :, :], one_hot_labels, max_len])
                pred_images, _, _ = parser.get_final_canvas(
                    test_outputs,
                    if_just_expressions=False,
                    if_pred_images=True)
                target_images = data_[-1, :, 0, :, :].astype(dtype=bool)
                CD += np.sum(chamfer(target_images, pred_images))

        metrics["cd"] = CD / inference_test_size

        # No validation token loss is computed, so it is reported as 0.0.
        test_loss = 0.0

        if metrics["cd"] >= best_test_cd:
            num_worse += 1
        else:
            num_worse = 0
            best_test_cd = metrics["cd"]
            torch.save(imitate_net.state_dict(), f"{path}/best_dict.pth")
        if num_worse >= patience:
            # Restore the best checkpoint and stop training.
            imitate_net.load_state_dict(torch.load(f"{path}/best_dict.pth"))
            return epoch + 1

        print(
            f"Epoch {epoch}/{epochs} =>  train_loss: {mean_train_loss}, iou: {0}, cd: {metrics['cd']}, test_mse: {test_loss}, test_acc: {0}"
        )

        end = time.time()
        print(f"Inference train time {end-start}")

        # Drop references to the last outputs before the next epoch. Rebinding
        # is safe even when one of the loops above ran zero times.
        outputs = test_outputs = None

    return epochs
Ejemplo n.º 2
0
                print('fetch data cost ' + str(time.time() - tick) + 'sec')
                tick = time.time()

                data = data[:, :, 0:config.top_k + 1, :, :, :]
                one_hot_labels = prepare_input_op(labels,
                                                  len(generator.unique_draw))
                one_hot_labels = Variable(
                    torch.from_numpy(one_hot_labels)).cuda()
                data = Variable(torch.from_numpy(data)).cuda()
                labels = Variable(torch.from_numpy(labels)).cuda()
                data = data.permute(1, 0, 2, 3, 4, 5)

                # forward pass
                outputs = imitate_net([data, one_hot_labels, k])

                loss = losses_joint(outputs, labels, time_steps=k + 1) / types_prog / \
                       num_accums
                loss.backward()
                loss_sum += loss.data

                print('train one batch cost' + str(time.time() - tick) + 'sec')

        # Clip the gradient to fixed value to stabilize training.
        torch.nn.utils.clip_grad_norm(imitate_net.parameters(), 20)
        optimizer.step()
        l = loss_sum
        train_loss += l
        log_value(
            'train_loss_batch',
            l.cpu().numpy(),
            epoch * (config.train_size // (config.batch_size * num_accums)) +
Ejemplo n.º 3
0
        loss = Variable(torch.zeros(1)).cuda().data
        acc = 0
        for _ in range(config.num_traj):
            for k in dataset_sizes.keys():
                data, labels = next(train_gen_objs[k])
                data = data[:, :, 0:1, :, :]
                one_hot_labels = prepare_input_op(labels,
                                                  len(generator.unique_draw))
                one_hot_labels = Variable(
                    torch.from_numpy(one_hot_labels)).cuda()
                data = Variable(torch.from_numpy(data)).cuda()
                labels = Variable(torch.from_numpy(labels)).cuda()
                outputs = imitate_net([data, one_hot_labels, k])
                #acc += float((torch.argmax(outputs, dim=2).permute(1, 0) == labels).float().sum()) \
                #       / (labels.shape[0] * labels.shape[1]) / types_prog / config.num_traj
                loss_k = (losses_joint(outputs, labels, time_steps=k + 1) /
                          (k + 1)) / len(
                              dataset_sizes.keys()) / config.num_traj
                loss_k.backward()
                loss += loss_k.data
                del loss_k

        optimizer.step()
        train_loss += loss
        print(f"batch {batch_idx} train loss: {loss.cpu().numpy()}")
        print(f"acc: {acc}")

    mean_train_loss = train_loss / (config.train_size // (config.batch_size))
    print(f"epoch {epoch} mean train loss: {mean_train_loss.cpu().numpy()}")
    imitate_net.eval()
    loss = Variable(torch.zeros(1)).cuda()
Ejemplo n.º 4
0
                print('fetch data cost ' + str(time.time() - tick) + 'sec')
                tick = time.time()

                data = data[:, :, 0:config.top_k + 1, :, :]
                one_hot_labels = prepare_input_op(labels,
                                                  len(generator.unique_draw))
                one_hot_labels = Variable(
                    torch.from_numpy(one_hot_labels)).cuda()
                data = Variable(torch.from_numpy(data)).cuda()
                labels = Variable(torch.from_numpy(labels)).cuda()
                data = data.permute(1, 0, 2, 3, 4)

                # forward pass
                outputs = imitate_net([data, one_hot_labels, k])

                loss = losses_joint(outputs, labels, time_steps=k + 1) / types_prog / \
                       num_accums
                loss.backward()
                loss_sum += loss.data

                print('train one batch cost' + str(time.time() - tick) + 'sec')

        # Clip the gradient to fixed value to stabilize training.
        torch.nn.utils.clip_grad_norm(imitate_net.parameters(), 20)
        optimizer.step()
        l = loss_sum
        train_loss += l
        log_value(
            'train_loss_batch',
            l.cpu().numpy(),
            epoch * (config.train_size // (config.batch_size * num_accums)) +
Ejemplo n.º 5
0
def train_inference(inference_net, iter):
    """Train the imitation network for one wake-sleep iteration.

    Training and validation labels are read from
    ``wake_sleep_data/inference/{iter}/labels/``. Each epoch reports the mean
    training loss and, on the validation split, the token loss, IoU, cosine
    similarity and Chamfer distance between target and predicted canvases.
    The Chamfer distance is fed to the ``LearningRate`` plateau schedule.
    Early stopping uses the validation token loss.

    Args:
        inference_net: ``(encoder_net, imitate_net)`` pair. Only the
            ``requires_grad`` parameters of ``imitate_net`` are optimized.
        iter: wake-sleep iteration index used to locate the data files.

    ``imitate_net`` is updated in place. If the validation loss fails to
    improve for ``patience`` consecutive epochs, the weights from the best
    epoch are restored and training stops.
    """
    import copy

    config = read_config.Config("config_synthetic.yml")

    generator = WakeSleepGen(
        f"wake_sleep_data/inference/{iter}/labels/labels.pt",
        f"wake_sleep_data/inference/{iter}/labels/val/labels.pt",
        batch_size=config.batch_size,
        train_size=inference_train_size,
        test_size=inference_test_size,
        canvas_shape=config.canvas_shape,
        max_len=max_len)

    train_gen = generator.get_train_data()
    test_gen = generator.get_test_data()

    _, imitate_net = inference_net

    optimizer = optim.Adam(
        [para for para in imitate_net.parameters() if para.requires_grad],
        weight_decay=config.weight_decay,
        lr=config.lr)

    reduce_plat = LearningRate(optimizer,
                               init_lr=config.lr,
                               lr_dacay_fact=0.2,
                               patience=config.patience)

    num_epochs = 50
    steps_per_epoch = inference_train_size // (config.batch_size *
                                               config.num_traj)
    num_test_batches = inference_test_size // config.batch_size

    best_test_loss = 1e20
    # state_dict() returns references to the live parameter tensors. Take a
    # deep copy; otherwise the "best" weights would follow every later update
    # and restoring them on early stopping would be a no-op.
    best_imitate_dict = copy.deepcopy(imitate_net.state_dict())

    patience = 5
    num_worse = 0

    for epoch in range(num_epochs):
        train_loss = torch.zeros(1, device=device)
        imitate_net.train()
        for batch_idx in range(steps_per_epoch):
            optimizer.zero_grad()
            loss = torch.zeros(1, device=device)
            # Accumulate gradients over num_traj mini-batches, then step once.
            for _ in range(config.num_traj):
                batch_data, batch_labels = next(train_gen)
                batch_data = batch_data.to(device)
                batch_labels = batch_labels.to(device)
                # Keep only the first slice along dim 2.
                batch_data = batch_data[:, :, 0:1, :, :]
                one_hot_labels = prepare_input_op(batch_labels, vocab_size)
                one_hot_labels = torch.from_numpy(one_hot_labels).to(device)
                outputs = imitate_net([batch_data, one_hot_labels, max_len])

                loss_k = (losses_joint(
                    outputs, batch_labels, time_steps=max_len + 1) /
                          (max_len + 1)) / config.num_traj
                loss_k.backward()
                loss += loss_k.detach()
                del loss_k

            optimizer.step()
            train_loss += loss
            print(f"batch {batch_idx} train loss: {loss.cpu().numpy()}")

        # Each step's loss is already averaged over num_traj mini-batches, so
        # average over optimizer steps rather than over all mini-batches.
        mean_train_loss = train_loss / max(steps_per_epoch, 1)
        print(
            f"epoch {epoch} mean train loss: {mean_train_loss.cpu().numpy()}")

        imitate_net.eval()
        loss = torch.zeros(1, device=device)
        metrics = {"cos": 0, "iou": 0, "cd": 0}
        IOU = 0
        COS = 0
        CD = 0
        for _ in range(num_test_batches):
            with torch.no_grad():
                batch_data, batch_labels = next(test_gen)
                batch_data = batch_data.to(device)
                batch_labels = batch_labels.to(device)
                one_hot_labels = prepare_input_op(batch_labels, vocab_size)
                one_hot_labels = torch.from_numpy(one_hot_labels).to(device)
                # NOTE(review): unlike training, batch_data is not sliced to
                # [:, :, 0:1] here. Confirm the test split already matches.
                test_outputs = imitate_net(
                    [batch_data, one_hot_labels, max_len])
                loss += (losses_joint(
                    test_outputs, batch_labels, time_steps=max_len + 1) /
                         (max_len + 1))
                test_output = imitate_net.test(
                    [batch_data, one_hot_labels, max_len])
                pred_images, _, _ = generator.parser.get_final_canvas(
                    test_output,
                    if_just_expressions=False,
                    if_pred_images=True)
                target_images = batch_data.cpu().numpy()[-1, :,
                                                         0, :, :].astype(
                                                             dtype=bool)
                iou = np.sum(np.logical_and(target_images, pred_images),
                             (1, 2)) / \
                      np.sum(np.logical_or(target_images, pred_images),
                             (1, 2))
                cos = cosine_similarity(target_images, pred_images)
                CD += np.sum(chamfer(target_images, pred_images))
                IOU += np.sum(iou)
                COS += np.sum(cos)

        metrics["iou"] = IOU / inference_test_size
        metrics["cos"] = COS / inference_test_size
        metrics["cd"] = CD / inference_test_size

        test_loss = loss.cpu().numpy() / max(num_test_batches, 1)

        if test_loss >= best_test_loss:
            num_worse += 1
        else:
            num_worse = 0
            best_test_loss = test_loss
            best_imitate_dict = copy.deepcopy(imitate_net.state_dict())
        if num_worse >= patience:
            # Restore the best weights and stop training.
            imitate_net.load_state_dict(best_imitate_dict)
            break

        reduce_plat.reduce_on_plateu(metrics["cd"])
        print("Epoch {}/{}=>  train_loss: {}, iou: {}, cd: {}, test_mse: {}".
              format(
                  epoch,
                  num_epochs,
                  mean_train_loss.cpu().numpy(),
                  metrics["iou"],
                  metrics["cd"],
                  test_loss,
              ))

        print(f"CORRECT PROGRAMS: {len(generator.correct_programs)}")

        # Drop references to the last validation outputs before the next
        # epoch. Rebinding is safe even if the test loop ran zero times.
        test_outputs = test_output = None
Ejemplo n.º 6
0
                perturbs = torch.from_numpy(perturbs).to(device)
                perturb_out = perturb_out.permute(1, 0, 2)

                # mask off ops and stop token
                perturb_loss = F.mse_loss(
                    perturbs[labels < 396], perturb_out[labels < 396]) / len(
                        dataset_sizes.keys()) / config.num_traj
                #perturb_loss = F.mse_loss(perturbs, perturb_out) / len(dataset_sizes.keys()) / config.num_traj
                if not imitate_net.tf:
                    acc += float((torch.argmax(torch.stack(outputs), dim=2).permute(1, 0) == labels).float().sum()) \
                           / (labels.shape[0] * labels.shape[1]) / types_prog / config.num_traj
                else:
                    acc += float((torch.argmax(outputs, dim=2).permute(1, 0) == labels).float().sum()) \
                           / (labels.shape[0] * labels.shape[1]) / types_prog / config.num_traj
                loss_k_token = (
                    (losses_joint(outputs, labels, time_steps=k + 1) /
                     (k + 1)) / len(dataset_sizes.keys()) / config.num_traj)
                #loss_k = loss_k_token + perturb_loss
                loss_k = loss_k_token
                loss_k.backward()
                loss += loss_k.data
                loss_p += perturb_loss.data
                loss_t += loss_k_token.data
                del loss_k

        optimizer.step()
        train_loss += loss
        print(
            f"batch {batch_idx} train loss: {loss.cpu().numpy()}, token loss: {loss_t.cpu().numpy()}, perturb loss: {loss_p.cpu().numpy()}"
        )
        print(f"acc: {acc}")
Ejemplo n.º 7
0
def train_model(csgnet, train_dataset, val_dataset, max_epochs=None):
    """Supervised training of ``csgnet`` with early stopping on validation loss.

    Each epoch trains on full-size batches from ``train_dataset``. The
    optimizer steps every ``config.num_traj`` batches, with gradient-norm
    clipping at 20. The model is then evaluated on ``val_dataset``:

    - token loss via ``losses_joint``;
    - token accuracy;
    - IoU between the canvases produced by ``parser.get_final_canvas`` and
      the input data.

    Args:
        csgnet: model exposing ``forward2`` and ``test2``.
        train_dataset: iterable of batches. Each batch is a sequence whose
            items are indexed as ``x[0]`` (labels) and ``x[1]`` (data).
            Batches whose size differs from ``config.batch_size`` are skipped.
        val_dataset: same format as ``train_dataset``.
        max_epochs: stored in the local ``epochs`` (default 100). See the
            NOTE on the epoch loop.
    """
    if max_epochs is None:
        epochs = 100
    else:
        epochs = max_epochs

    optimizer = optim.Adam(
        [para for para in csgnet.parameters() if para.requires_grad],
        weight_decay=config.weight_decay,
        lr=config.lr)

    reduce_plat = LearningRate(optimizer,
                               init_lr=config.lr,
                               lr_dacay_fact=0.2,
                               lr_decay_epoch=3,
                               patience=config.patience)

    best_state_dict = None
    # Stop after this many consecutive epochs without a lower validation loss.
    patience = 3
    prev_test_loss = 1e20
    prev_test_reward = 0
    num_worse = 0
    # NOTE(review): iterates a fixed 100 epochs; `epochs` (from max_epochs)
    # is computed above but never used here.
    for epoch in range(100):
        train_loss = 0
        # NOTE(review): unused.
        Accuracies = []
        csgnet.train()
        # Number of times to accumulate gradients
        num_accums = config.num_traj
        batch_idx = 0
        count = 0
        for batch in train_dataset:
            labels = np.stack([x[0] for x in batch])
            data = np.stack([x[1] for x in batch])
            # Skip ragged (e.g. final) batches.
            if not len(labels) == config.batch_size:
                continue
            # NOTE(review): zero_grad() and the loss_sum reset run on every
            # batch, but optimizer.step() runs only every num_accums batches.
            # The gradients (and the logged loss) of all but the last batch in
            # each group are therefore discarded. If accumulation is intended,
            # zero/reset only when batch_idx % num_accums == 0.
            optimizer.zero_grad()
            loss_sum = Variable(torch.zeros(1)).cuda().data

            one_hot_labels = prepare_input_op(labels, len(unique_draws))
            one_hot_labels = Variable(torch.from_numpy(one_hot_labels)).cuda()
            # Adds a trailing singleton dimension and casts to float.
            data = Variable(
                torch.from_numpy(data)).cuda().unsqueeze(-1).float()
            labels = Variable(torch.from_numpy(labels)).cuda()

            # forward pass
            outputs = csgnet.forward2([data, one_hot_labels, max_len])

            # Scale by 1/num_accums so that accumulated gradients average.
            loss = losses_joint(outputs, labels,
                                time_steps=max_len + 1) / num_accums
            loss.backward()
            loss_sum += loss.data

            batch_idx += 1
            count += len(data)

            if batch_idx % num_accums == 0:
                # Clip the gradient to fixed value to stabilize training.
                torch.nn.utils.clip_grad_norm_(csgnet.parameters(), 20)
                optimizer.step()
                l = loss_sum
                train_loss += l
                # print(f'train loss batch {batch_idx}: {l}')

        # Undo the 1/num_accums scaling and average over full-size batches.
        # NOTE(review): ZeroDivisionError if no full-size batch was seen.
        mean_train_loss = (train_loss * num_accums) / (count //
                                                       config.batch_size)
        print(f'train loss epoch {epoch}: {float(mean_train_loss)}')
        # Free training tensors before validation.
        # NOTE(review): NameError if no full-size training batch was seen.
        del data, loss, loss_sum, train_loss, outputs

        test_losses = 0
        acc = 0
        csgnet.eval()
        test_reward = 0
        batch_idx = 0
        count = 0
        for batch in val_dataset:
            labels = np.stack([x[0] for x in batch])
            data = np.stack([x[1] for x in batch])
            if not len(labels) == config.batch_size:
                continue
            # NOTE(review): rebuilt for every batch although its arguments do
            # not change; it could be created once outside the loop.
            parser = ParseModelOutput(unique_draws,
                                      stack_size=(max_len + 1) // 2 + 1,
                                      steps=max_len,
                                      canvas_shape=[64, 64, 64],
                                      primitives=primitives)
            with torch.no_grad():
                one_hot_labels = prepare_input_op(labels, len(unique_draws))
                one_hot_labels = Variable(
                    torch.from_numpy(one_hot_labels)).cuda()
                data = Variable(
                    torch.from_numpy(data)).cuda().unsqueeze(-1).float()
                labels = Variable(torch.from_numpy(labels)).cuda()

                # Teacher-forced pass: token loss and per-token accuracy.
                test_output = csgnet.forward2([data, one_hot_labels, max_len])

                l = losses_joint(test_output, labels,
                                 time_steps=max_len + 1).data
                test_losses += l
                acc += float((torch.argmax(torch.stack(test_output), dim=2).permute(1, 0) == labels).float().sum()) \
                / (labels.shape[0] * labels.shape[1])

                # Free-running decode, rendered to canvases for the IoU reward.
                test_output = csgnet.test2(data, max_len)

                stack, _, _ = parser.get_final_canvas(
                    test_output,
                    if_pred_images=True,
                    if_just_expressions=False)
                # NOTE(review): squeeze() would also drop the batch axis if
                # config.batch_size == 1.
                data_ = data.squeeze().cpu().numpy()
                # Per-sample IoU over axes (1, 2, 3); the +1 in the denominator
                # avoids division by zero and slightly lowers the score.
                R = np.sum(np.logical_and(stack, data_),
                           (1, 2, 3)) / (np.sum(np.logical_or(stack, data_),
                                                (1, 2, 3)) + 1)
                test_reward += np.sum(R)

            batch_idx += 1
            count += len(data)

        # Mean IoU per validation sample.
        # NOTE(review): ZeroDivisionError if no full-size validation batch.
        test_reward = test_reward / count

        test_loss = test_losses / (count // config.batch_size)
        acc = acc / (count // config.batch_size)

        if test_loss < prev_test_loss:
            prev_test_loss = test_loss
            # NOTE(review): state_dict() returns references to the live
            # parameters, so this "best" dict keeps changing with further
            # training, and the reload below restores the current weights.
            # copy.deepcopy(csgnet.state_dict()) would take a real snapshot.
            best_state_dict = csgnet.state_dict()
            num_worse = 0
        else:
            num_worse += 1
        if num_worse >= patience:
            csgnet.load_state_dict(best_state_dict)
            break

        print(f'test loss epoch {epoch}: {float(test_loss)}')
        print(f'test IOU epoch {epoch}: {test_reward}')
        print(f'test acc epoch {epoch}: {acc}')
        if config.if_schedule:
            # IoU is negated, presumably because reduce_on_plateu treats lower
            # values as better. Confirm against LearningRate.
            reduce_plat.reduce_on_plateu(-test_reward)

        del test_losses, test_output
        if test_reward > prev_test_reward:
            prev_test_reward = test_reward