Example #1
def test():
    t = time.time()
    nll_test = []
    nll_var_test = []

    mse_1_test = []
    mse_10_test = []
    mse_20_test = []

    kl_test = []
    kl_list_test = []
    kl_var_list_test = []

    acc_test = []
    acc_var_test = []
    acc_blocks_test = []
    acc_var_blocks_test = []
    perm_test = []

    KLb_test = []
    KLb_blocks_test = []  # KL between blocks list

    nll_M_test = []
    nll_M_var_test = []

    encoder.eval()
    decoder.eval()
    if not args.cuda:
        encoder.load_state_dict(torch.load(encoder_file, map_location='cpu'))
        decoder.load_state_dict(torch.load(decoder_file, map_location='cpu'))
    else:
        encoder.load_state_dict(torch.load(encoder_file))
        decoder.load_state_dict(torch.load(decoder_file))

    for batch_idx, (data, relations) in enumerate(test_loader):
        with torch.no_grad():
            if args.cuda:
                data, relations = data.cuda(), relations.cuda()

            assert (data.size(2) - args.timesteps) >= args.timesteps
            data_encoder = data[:, :, :args.timesteps, :].contiguous()
            data_decoder = data[:, :, -args.timesteps:, :].contiguous()

            # dim of logits, edges and prob are [batchsize, N^2-N, sum(edge_types_list)] where N = no. of particles
            logits = encoder(data_encoder, rel_rec, rel_send)

            if args.NRI:
                edges = gumbel_softmax(logits, tau=args.temp, hard=args.hard)
                prob = my_softmax(logits, -1)

                loss_kl = kl_categorical_uniform(prob, args.num_atoms,
                                                 edge_types)
                loss_kl_split = [loss_kl]
                loss_kl_var_split = [
                    kl_categorical_uniform_var(prob, args.num_atoms,
                                               edge_types)
                ]

                KLb_test.append(0)
                KLb_blocks_test.append([0])

                acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = edge_accuracy_perm_NRI(
                    logits, relations, args.edge_types_list)

            else:
                logits_split = torch.split(logits,
                                           args.edge_types_list,
                                           dim=-1)
                edges_split = tuple([
                    gumbel_softmax(logits_i, tau=args.temp, hard=args.hard)
                    for logits_i in logits_split
                ])
                edges = torch.cat(edges_split, dim=-1)
                prob_split = [
                    my_softmax(logits_i, -1) for logits_i in logits_split
                ]

                if args.prior:
                    loss_kl_split = [
                        kl_categorical(prob_split[type_idx],
                                       log_prior[type_idx], args.num_atoms)
                        for type_idx in range(len(args.edge_types_list))
                    ]
                else:
                    loss_kl_split = [
                        kl_categorical_uniform(prob_split[type_idx],
                                               args.num_atoms,
                                               args.edge_types_list[type_idx])
                        for type_idx in range(len(args.edge_types_list))
                    ]
                loss_kl = sum(loss_kl_split)

                # Computed outside the if/else so it is defined in both the
                # prior and uniform cases; it is logged unconditionally below.
                loss_kl_var_split = [
                    kl_categorical_uniform_var(
                        prob_split[type_idx], args.num_atoms,
                        args.edge_types_list[type_idx])
                    for type_idx in range(len(args.edge_types_list))
                ]

                acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = edge_accuracy_perm_fNRI(
                    logits_split, relations, args.edge_types_list,
                    args.skip_first)

                KLb_blocks = KL_between_blocks(prob_split, args.num_atoms)
                KLb_test.append(sum(KLb_blocks).data.item())
                KLb_blocks_test.append([KL.data.item() for KL in KLb_blocks])

            target = data_decoder[:, :,
                                  1:, :]  # dimensions are [batch, particle, time, state]
            output = decoder(data_decoder, edges, rel_rec, rel_send, 1)

            if args.plot:
                import matplotlib.pyplot as plt
                output_plot = decoder(data_decoder, edges, rel_rec, rel_send,
                                      49)

                output_plot_en = decoder(data_encoder, edges, rel_rec,
                                         rel_send, 49)
                from trajectory_plot import draw_lines

                if args.NRI:
                    acc_batch, perm, acc_blocks_batch = edge_accuracy_perm_NRI_batch(
                        logits, relations, args.edge_types_list)
                else:
                    acc_batch, perm, acc_blocks_batch = edge_accuracy_perm_fNRI_batch(
                        logits_split, relations, args.edge_types_list)

                for i in range(args.batch_size):
                    fig = plt.figure(figsize=(7, 7))
                    ax = fig.add_axes([0, 0, 1, 1])
                    xmin_t, ymin_t, xmax_t, ymax_t = draw_lines(target,
                                                                i,
                                                                linestyle=':',
                                                                alpha=0.6)
                    xmin_o, ymin_o, xmax_o, ymax_o = draw_lines(
                        output_plot.detach().numpy(), i, linestyle='-')

                    ax.set_xlim([min(xmin_t, xmin_o), max(xmax_t, xmax_o)])
                    ax.set_ylim([min(ymin_t, ymin_o), max(ymax_t, ymax_o)])
                    ax.set_xticks([])
                    ax.set_yticks([])
                    block_names = [
                        str(j) for j in range(len(args.edge_types_list))
                    ]
                    acc_text = [
                        'layer ' + block_names[j] +
                        ' acc: {:02.0f}%'.format(100 * acc_blocks_batch[i, j])
                        for j in range(acc_blocks_batch.shape[1])
                    ]
                    acc_text = ', '.join(acc_text)
                    plt.text(0.5,
                             0.95,
                             acc_text,
                             horizontalalignment='center',
                             transform=ax.transAxes)
                    plt.show()

            loss_nll = nll_gaussian(
                output, target, args.var
            )  # compute the reconstruction loss. nll_gaussian is from utils.py
            loss_nll_var = nll_gaussian_var(output, target, args.var)

            output_M = decoder(data_decoder, edges, rel_rec, rel_send,
                               args.prediction_steps)
            loss_nll_M = nll_gaussian(output_M, target, args.var)
            loss_nll_M_var = nll_gaussian_var(output_M, target, args.var)

            perm_test.append(perm)
            acc_test.append(acc_perm)
            acc_blocks_test.append(acc_blocks)
            acc_var_test.append(acc_var)
            acc_var_blocks_test.append(acc_var_blocks)

            output_10 = decoder(data_decoder, edges, rel_rec, rel_send, 10)
            output_20 = decoder(data_decoder, edges, rel_rec, rel_send, 20)
            mse_1_test.append(F.mse_loss(output, target).data.item())
            mse_10_test.append(F.mse_loss(output_10, target).data.item())
            mse_20_test.append(F.mse_loss(output_20, target).data.item())

            nll_test.append(loss_nll.data.item())
            kl_test.append(loss_kl.data.item())
            kl_list_test.append(
                [kl_loss.data.item() for kl_loss in loss_kl_split])

            nll_var_test.append(loss_nll_var.data.item())
            kl_var_list_test.append(
                [kl_var.data.item() for kl_var in loss_kl_var_split])

            nll_M_test.append(loss_nll_M.data.item())
            nll_M_var_test.append(loss_nll_M_var.data.item())

    print('--------------------------------')
    print('------------Testing-------------')
    print('--------------------------------')
    print('nll_test: {:.2f}'.format(np.mean(nll_test)),
          'nll_M_test: {:.2f}'.format(np.mean(nll_M_test)),
          'kl_test: {:.5f}'.format(np.mean(kl_test)),
          'mse_1_test: {:.10f}'.format(np.mean(mse_1_test)),
          'mse_10_test: {:.10f}'.format(np.mean(mse_10_test)),
          'mse_20_test: {:.10f}'.format(np.mean(mse_20_test)),
          'acc_test: {:.5f}'.format(np.mean(acc_test)),
          'acc_var_test: {:.5f}'.format(np.mean(acc_var_test)),
          'KLb_test: {:.5f}'.format(np.mean(KLb_test)),
          'time: {:.1f}s'.format(time.time() - t))
    print(
        'acc_b_test: ' +
        str(np.around(np.mean(np.array(acc_blocks_test), axis=0), 4)),
        'acc_var_b: ' +
        str(np.around(np.mean(np.array(acc_var_blocks_test), axis=0), 4)),
        'kl_test: ' +
        str(np.around(np.mean(np.array(kl_list_test), axis=0), 4)))
    if args.save_folder:
        print('--------------------------------', file=log)
        print('------------Testing-------------', file=log)
        print('--------------------------------', file=log)
        print('nll_test: {:.2f}'.format(np.mean(nll_test)),
              'nll_M_test: {:.2f}'.format(np.mean(nll_M_test)),
              'kl_test: {:.5f}'.format(np.mean(kl_test)),
              'mse_1_test: {:.10f}'.format(np.mean(mse_1_test)),
              'mse_10_test: {:.10f}'.format(np.mean(mse_10_test)),
              'mse_20_test: {:.10f}'.format(np.mean(mse_20_test)),
              'acc_test: {:.5f}'.format(np.mean(acc_test)),
              'acc_var_test: {:.5f}'.format(np.mean(acc_var_test)),
              'KLb_test: {:.5f}'.format(np.mean(KLb_test)),
              'time: {:.1f}s'.format(time.time() - t),
              file=log)
        print(
            'acc_b_test: ' +
            str(np.around(np.mean(np.array(acc_blocks_test), axis=0), 4)),
            'acc_var_b_test: ' +
            str(np.around(np.mean(np.array(acc_var_blocks_test), axis=0), 4)),
            'kl_test: ' +
            str(np.around(np.mean(np.array(kl_list_test), axis=0), 4)),
            file=log)
        log.flush()
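
Note: nll_gaussian, nll_gaussian_var, kl_categorical, kl_categorical_uniform
and kl_categorical_uniform_var are imported from the project's utils.py. A
rough sketch of the assumed conventions, following the original NRI codebase
(the real implementations may differ, and the _var variants presumably return
per-sample variances rather than means):

import numpy as np
import torch

def nll_gaussian(preds, target, variance, add_const=False):
    # Fixed-variance Gaussian negative log-likelihood, averaged over
    # batch and particles (assumed normalisation).
    neg_log_p = ((preds - target) ** 2) / (2 * variance)
    if add_const:
        neg_log_p = neg_log_p + 0.5 * np.log(2 * np.pi * variance)
    return neg_log_p.sum() / (target.size(0) * target.size(1))

def kl_categorical_uniform(preds, num_atoms, num_edge_types, eps=1e-16):
    # KL from the predicted edge-type posterior to a uniform prior over
    # num_edge_types categories, dropping the constant log(num_edge_types);
    # kl_categorical is analogous but uses a supplied log-prior.
    kl_div = preds * torch.log(preds + eps)
    return kl_div.sum() / (num_atoms * preds.size(0))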
Example #2
def train(epoch, best_val_loss):
    t = time.time()
    nll_train = []
    nll_var_train = []
    mse_train = []

    kl_train = []
    kl_list_train = []
    kl_var_list_train = []

    acc_train = []
    acc_var_train = []
    perm_train = []
    acc_var_blocks_train = []
    acc_blocks_train = []

    KLb_train = []
    KLb_blocks_train = []

    encoder.train()
    decoder.train()
    scheduler.step()
    if not args.plot:
        for batch_idx, (data, relations) in enumerate(
                train_loader
        ):  # relations are the ground truth interactions graphs
            if args.cuda:
                data, relations = data.cuda(), relations.cuda()
            data, relations = Variable(data), Variable(relations)

            if args.dont_split_data:
                data_encoder = data[:, :, :args.timesteps, :].contiguous()
                data_decoder = data[:, :, :args.timesteps, :].contiguous()
            elif args.split_enc_only:
                data_encoder = data[:, :, :args.timesteps, :].contiguous()
                data_decoder = data
            else:
                assert (data.size(2) - args.timesteps) >= args.timesteps
                data_encoder = data[:, :, :args.timesteps, :].contiguous()
                data_decoder = data[:, :, -args.timesteps:, :].contiguous()

            optimizer.zero_grad()

            logits = encoder(data_encoder, rel_rec, rel_send)

            if args.NRI:
                # dim of logits, edges and prob are [batchsize, N^2-N, edgetypes] where N = no. of particles
                edges = gumbel_softmax(logits, tau=args.temp, hard=args.hard)
                prob = my_softmax(logits, -1)

                loss_kl = kl_categorical_uniform(prob, args.num_atoms,
                                                 edge_types)
                loss_kl_split = [loss_kl]
                loss_kl_var_split = [
                    kl_categorical_uniform_var(prob, args.num_atoms,
                                               edge_types)
                ]

                KLb_train.append(0)
                KLb_blocks_train.append([0])

                if args.no_edge_acc:
                    acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = 0, np.array(
                        [0]), np.zeros(len(args.edge_types_list)), 0, np.zeros(
                            len(args.edge_types_list))
                else:
                    acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = edge_accuracy_perm_NRI(
                        logits, relations, args.edge_types_list)

            else:
                # dim of logits, edges and prob are [batchsize, N^2-N, sum(edge_types_list)] where N = no. of particles
                logits_split = torch.split(logits,
                                           args.edge_types_list,
                                           dim=-1)
                edges_split = tuple([
                    gumbel_softmax(logits_i, tau=args.temp, hard=args.hard)
                    for logits_i in logits_split
                ])
                edges = torch.cat(edges_split, dim=-1)
                prob_split = [
                    my_softmax(logits_i, -1) for logits_i in logits_split
                ]

                if args.prior:
                    loss_kl_split = [
                        kl_categorical(prob_split[type_idx],
                                       log_prior[type_idx], args.num_atoms)
                        for type_idx in range(len(args.edge_types_list))
                    ]
                else:
                    loss_kl_split = [
                        kl_categorical_uniform(prob_split[type_idx],
                                               args.num_atoms,
                                               args.edge_types_list[type_idx])
                        for type_idx in range(len(args.edge_types_list))
                    ]
                loss_kl = sum(loss_kl_split)

                # Computed outside the if/else so it is defined in both the
                # prior and uniform cases; it is logged unconditionally below.
                loss_kl_var_split = [
                    kl_categorical_uniform_var(
                        prob_split[type_idx], args.num_atoms,
                        args.edge_types_list[type_idx])
                    for type_idx in range(len(args.edge_types_list))
                ]

                if args.no_edge_acc:
                    acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = 0, np.array(
                        [0]), np.zeros(len(args.edge_types_list)), 0, np.zeros(
                            len(args.edge_types_list))
                else:
                    acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = edge_accuracy_perm_fNRI(
                        logits_split, relations, args.edge_types_list,
                        args.skip_first)

                KLb_blocks = KL_between_blocks(prob_split, args.num_atoms)
                KLb_train.append(sum(KLb_blocks).data.item())
                KLb_blocks_train.append([KL.data.item() for KL in KLb_blocks])

            target = data_decoder[:, :,
                                  1:, :]  # dimensions are [batch, particle, time, state]
            output = decoder(data_decoder, edges, rel_rec, rel_send,
                             args.prediction_steps)

            loss_nll = nll_gaussian(output, target, args.var)
            loss_nll_var = nll_gaussian_var(output, target, args.var)

            if args.mse_loss:
                loss = F.mse_loss(output, target)
            else:
                loss = loss_nll
                if not math.isclose(args.beta, 0, rel_tol=1e-6):
                    loss += args.beta * loss_kl

            perm_train.append(perm)
            acc_train.append(acc_perm)
            acc_blocks_train.append(acc_blocks)
            acc_var_train.append(acc_var)
            acc_var_blocks_train.append(acc_var_blocks)

            loss.backward()
            optimizer.step()

            mse_train.append(F.mse_loss(output, target).data.item())
            nll_train.append(loss_nll.data.item())
            kl_train.append(loss_kl.data.item())
            kl_list_train.append([kl.data.item() for kl in loss_kl_split])

            nll_var_train.append(loss_nll_var.data.item())
            kl_var_list_train.append(
                [kl_var.data.item() for kl_var in loss_kl_var_split])

    nll_val = []
    nll_var_val = []
    mse_val = []

    kl_val = []
    kl_list_val = []
    kl_var_list_val = []

    acc_val = []
    acc_var_val = []
    acc_blocks_val = []
    acc_var_blocks_val = []
    perm_val = []

    KLb_val = []
    KLb_blocks_val = []  # KL between blocks list

    nll_M_val = []
    nll_M_var_val = []

    encoder.eval()
    decoder.eval()
    for batch_idx, (data, relations) in enumerate(valid_loader):
        with torch.no_grad():
            if args.cuda:
                data, relations = data.cuda(), relations.cuda()

            if args.dont_split_data:
                data_encoder = data[:, :, :args.timesteps, :].contiguous()
                data_decoder = data[:, :, :args.timesteps, :].contiguous()
            elif args.split_enc_only:
                data_encoder = data[:, :, :args.timesteps, :].contiguous()
                data_decoder = data
            else:
                assert (data.size(2) - args.timesteps) >= args.timesteps
                data_encoder = data[:, :, :args.timesteps, :].contiguous()
                data_decoder = data[:, :, -args.timesteps:, :].contiguous()

            # dim of logits, edges and prob are [batchsize, N^2-N, sum(edge_types_list)] where N = no. of particles
            logits = encoder(data_encoder, rel_rec, rel_send)

            if args.NRI:
                # dim of logits, edges and prob are [batchsize, N^2-N, edgetypes] where N = no. of particles
                edges = gumbel_softmax(
                    logits, tau=args.temp, hard=args.hard
                )  # uses concrete distribution (for hard=False) to sample edge types
                prob = my_softmax(
                    logits,
                    -1)  # my_softmax returns the softmax over the edgetype dim

                loss_kl = kl_categorical_uniform(prob, args.num_atoms,
                                                 edge_types)
                loss_kl_split = [loss_kl]
                loss_kl_var_split = [
                    kl_categorical_uniform_var(prob, args.num_atoms,
                                               edge_types)
                ]

                KLb_val.append(0)
                KLb_blocks_val.append([0])

                if args.no_edge_acc:
                    acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = 0, np.array(
                        [0]), np.zeros(len(args.edge_types_list)), 0, np.zeros(
                            len(args.edge_types_list))
                else:
                    acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = edge_accuracy_perm_NRI(
                        logits, relations, args.edge_types_list)

            else:
                # dim of logits, edges and prob are [batchsize, N^2-N, sum(edge_types_list)] where N = no. of particles
                logits_split = torch.split(logits,
                                           args.edge_types_list,
                                           dim=-1)
                edges_split = tuple([
                    gumbel_softmax(logits_i, tau=args.temp, hard=args.hard)
                    for logits_i in logits_split
                ])
                edges = torch.cat(edges_split, dim=-1)
                prob_split = [
                    my_softmax(logits_i, -1) for logits_i in logits_split
                ]

                if args.prior:
                    loss_kl_split = [
                        kl_categorical(prob_split[type_idx],
                                       log_prior[type_idx], args.num_atoms)
                        for type_idx in range(len(args.edge_types_list))
                    ]
                else:
                    loss_kl_split = [
                        kl_categorical_uniform(prob_split[type_idx],
                                               args.num_atoms,
                                               args.edge_types_list[type_idx])
                        for type_idx in range(len(args.edge_types_list))
                    ]
                loss_kl = sum(loss_kl_split)

                # Computed outside the if/else so it is defined in both the
                # prior and uniform cases; it is logged unconditionally below.
                loss_kl_var_split = [
                    kl_categorical_uniform_var(
                        prob_split[type_idx], args.num_atoms,
                        args.edge_types_list[type_idx])
                    for type_idx in range(len(args.edge_types_list))
                ]

                if args.no_edge_acc:
                    acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = 0, np.array(
                        [0]), np.zeros(len(args.edge_types_list)), 0, np.zeros(
                            len(args.edge_types_list))
                else:
                    acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = edge_accuracy_perm_fNRI(
                        logits_split, relations, args.edge_types_list,
                        args.skip_first)

                KLb_blocks = KL_between_blocks(prob_split, args.num_atoms)
                KLb_val.append(sum(KLb_blocks).data.item())
                KLb_blocks_val.append([KL.data.item() for KL in KLb_blocks])

            target = data_decoder[:, :,
                                  1:, :]  # dimensions are [batch, particle, time, state]
            output = decoder(data_decoder, edges, rel_rec, rel_send, 1)

            if args.plot:
                import matplotlib.pyplot as plt
                output_plot = decoder(data_decoder, edges, rel_rec, rel_send,
                                      49)

                if args.NRI:
                    acc_batch, perm, acc_blocks_batch = edge_accuracy_perm_NRI_batch(
                        logits, relations, args.edge_types_list)
                else:
                    acc_batch, perm, acc_blocks_batch = edge_accuracy_perm_fNRI_batch(
                        logits_split, relations, args.edge_types_list)

                from trajectory_plot import draw_lines
                for i in range(args.batch_size):
                    fig = plt.figure(figsize=(7, 7))
                    ax = fig.add_axes([0, 0, 1, 1])
                    xmin_t, ymin_t, xmax_t, ymax_t = draw_lines(target,
                                                                i,
                                                                linestyle=':',
                                                                alpha=0.6)
                    xmin_o, ymin_o, xmax_o, ymax_o = draw_lines(
                        output_plot.detach().numpy(), i, linestyle='-')

                    ax.set_xlim([min(xmin_t, xmin_o), max(xmax_t, xmax_o)])
                    ax.set_ylim([min(ymin_t, ymin_o), max(ymax_t, ymax_o)])
                    ax.set_xticks([])
                    ax.set_yticks([])
                    block_names = [
                        'layer ' + str(j)
                        for j in range(len(args.edge_types_list))
                    ]
                    #block_names = [ 'springs', 'charges' ]
                    acc_text = [
                        block_names[j] +
                        ' acc: {:02.0f}%'.format(100 * acc_blocks_batch[i, j])
                        for j in range(acc_blocks_batch.shape[1])
                    ]
                    acc_text = ', '.join(acc_text)
                    plt.text(0.5,
                             0.95,
                             acc_text,
                             horizontalalignment='center',
                             transform=ax.transAxes)
                    #plt.savefig(os.path.join(args.load_folder,str(i)+'_pred_and_true.png'), dpi=300)
                    plt.show()

            loss_nll = nll_gaussian(output, target, args.var)
            loss_nll_var = nll_gaussian_var(output, target, args.var)

            output_M = decoder(data_decoder, edges, rel_rec, rel_send,
                               args.prediction_steps)
            loss_nll_M = nll_gaussian(output_M, target, args.var)
            loss_nll_M_var = nll_gaussian_var(output_M, target, args.var)

            perm_val.append(perm)
            acc_val.append(acc_perm)
            acc_blocks_val.append(acc_blocks)
            acc_var_val.append(acc_var)
            acc_var_blocks_val.append(acc_var_blocks)

            mse_val.append(F.mse_loss(output_M, target).data.item())
            nll_val.append(loss_nll.data.item())
            nll_var_val.append(loss_nll_var.data.item())

            kl_val.append(loss_kl.data.item())
            kl_list_val.append(
                [kl_loss.data.item() for kl_loss in loss_kl_split])
            kl_var_list_val.append(
                [kl_var.data.item() for kl_var in loss_kl_var_split])

            nll_M_val.append(loss_nll_M.data.item())
            nll_M_var_val.append(loss_nll_M_var.data.item())

    print(
        'Epoch: {:03d}'.format(epoch),
        'perm_val: ' + str(np.around(np.mean(np.array(perm_val), axis=0), 4)),
        'time: {:.1f}s'.format(time.time() - t))
    print('nll_trn: {:.2f}'.format(np.mean(nll_train)),
          'kl_trn: {:.5f}'.format(np.mean(kl_train)),
          'mse_trn: {:.10f}'.format(np.mean(mse_train)),
          'acc_trn: {:.5f}'.format(np.mean(acc_train)),
          'KLb_trn: {:.5f}'.format(np.mean(KLb_train)))
    print(
        'acc_b_trn: ' +
        str(np.around(np.mean(np.array(acc_blocks_train), axis=0), 4)),
        'kl_trn: ' +
        str(np.around(np.mean(np.array(kl_list_train), axis=0), 4)))
    print('nll_val: {:.2f}'.format(np.mean(nll_M_val)),
          'kl_val: {:.5f}'.format(np.mean(kl_val)),
          'mse_val: {:.10f}'.format(np.mean(mse_val)),
          'acc_val: {:.5f}'.format(np.mean(acc_val)),
          'KLb_val: {:.5f}'.format(np.mean(KLb_val)))
    print(
        'acc_b_val: ' +
        str(np.around(np.mean(np.array(acc_blocks_val), axis=0), 4)),
        'kl_val: ' + str(np.around(np.mean(np.array(kl_list_val), axis=0), 4)))
    print('Epoch: {:04d}'.format(epoch),
          'perm_val: ' +
          str(np.around(np.mean(np.array(perm_val), axis=0), 4)),
          'time: {:.4f}s'.format(time.time() - t),
          file=log)
    print('nll_trn: {:.5f}'.format(np.mean(nll_train)),
          'kl_trn: {:.5f}'.format(np.mean(kl_train)),
          'mse_trn: {:.10f}'.format(np.mean(mse_train)),
          'acc_trn: {:.5f}'.format(np.mean(acc_train)),
          'KLb_trn: {:.5f}'.format(np.mean(KLb_train)),
          'acc_b_trn: ' +
          str(np.around(np.mean(np.array(acc_blocks_train), axis=0), 4)),
          'kl_trn: ' +
          str(np.around(np.mean(np.array(kl_list_train), axis=0), 4)),
          file=log)
    print('nll_val: {:.5f}'.format(np.mean(nll_M_val)),
          'kl_val: {:.5f}'.format(np.mean(kl_val)),
          'mse_val: {:.10f}'.format(np.mean(mse_val)),
          'acc_val: {:.5f}'.format(np.mean(acc_val)),
          'KLb_val: {:.5f}'.format(np.mean(KLb_val)),
          'acc_b_val: ' +
          str(np.around(np.mean(np.array(acc_blocks_val), axis=0), 4)),
          'kl_val: ' +
          str(np.around(np.mean(np.array(kl_list_val), axis=0), 4)),
          file=log)
    if epoch == 0:
        labels = [
            'epoch', 'nll trn', 'kl trn', 'mse train', 'KLb trn', 'acc trn'
        ]
        labels += [
            'b' + str(i) + ' acc trn' for i in range(len(args.edge_types_list))
        ] + ['nll var trn']
        labels += [
            'b' + str(i) + ' kl trn' for i in range(len(kl_list_train[0]))
        ]
        labels += [
            'b' + str(i) + ' kl var trn' for i in range(len(kl_list_train[0]))
        ]
        labels += ['acc var trn'] + [
            'b' + str(i) + ' acc var trn'
            for i in range(len(args.edge_types_list))
        ]
        labels += [
            'nll val', 'nll_M_val', 'kl val', 'mse val', 'KLb val', 'acc val'
        ]
        labels += [
            'b' + str(i) + ' acc val' for i in range(len(args.edge_types_list))
        ]
        labels += ['nll var val', 'nll_M var val']
        labels += [
            'b' + str(i) + ' kl val' for i in range(len(kl_list_val[0]))
        ]
        labels += [
            'b' + str(i) + ' kl var val' for i in range(len(kl_list_val[0]))
        ]
        labels += ['acc var val'] + [
            'b' + str(i) + ' acc var val'
            for i in range(len(args.edge_types_list))
        ]
        csv_writer.writerow(labels)

        labels = ['trn ' + str(i) for i in range(len(perm_train[0]))]
        labels += ['val ' + str(i) for i in range(len(perm_val[0]))]
        perm_writer.writerow(labels)

    csv_writer.writerow(
        [
            epoch,
            np.mean(nll_train),
            np.mean(kl_train),
            np.mean(mse_train),
            np.mean(KLb_train),
            np.mean(acc_train)
        ] + list(np.mean(np.array(acc_blocks_train), axis=0)) +
        [np.mean(nll_var_train)] +
        list(np.mean(np.array(kl_list_train), axis=0)) +
        list(np.mean(np.array(kl_var_list_train), axis=0)) +
        #list(np.mean(np.array(KLb_blocks_train),axis=0)) +
        [np.mean(acc_var_train)] +
        list(np.mean(np.array(acc_var_blocks_train), axis=0)) + [
            np.mean(nll_val),
            np.mean(nll_M_val),
            np.mean(kl_val),
            np.mean(mse_val),
            np.mean(KLb_val),
            np.mean(acc_val)
        ] + list(np.mean(np.array(acc_blocks_val), axis=0)) +
        [np.mean(nll_var_val), np.mean(nll_M_var_val)] +
        list(np.mean(np.array(kl_list_val), axis=0)) +
        list(np.mean(np.array(kl_var_list_val), axis=0)) +
        #list(np.mean(np.array(KLb_blocks_val),axis=0))
        [np.mean(acc_var_val)] +
        list(np.mean(np.array(acc_var_blocks_val), axis=0)))
    perm_writer.writerow(
        list(np.mean(np.array(perm_train), axis=0)) +
        list(np.mean(np.array(perm_val), axis=0)))

    log.flush()
    if args.save_folder and np.mean(nll_M_val) < best_val_loss:
        torch.save(encoder.state_dict(), encoder_file)
        torch.save(decoder.state_dict(), decoder_file)
        print('Best model so far, saving...')
    return np.mean(nll_M_val)
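
Note: gumbel_softmax and my_softmax also come from utils.py. A minimal sketch
of the assumed behaviour (standard Gumbel-softmax sampling with an optional
straight-through hard mode; details may differ from the actual utils.py):

import torch
import torch.nn.functional as F

def my_softmax(logits, axis=-1):
    # Softmax over the edge-type dimension.
    return F.softmax(logits, dim=axis)

def gumbel_softmax(logits, tau=1.0, hard=False, eps=1e-10):
    # Sample from the concrete (Gumbel-softmax) distribution; with
    # hard=True, discretise to one-hot but keep the soft sample's
    # gradient (straight-through estimator).
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + eps) + eps)
    y_soft = F.softmax((logits + gumbel_noise) / tau, dim=-1)
    if hard:
        index = y_soft.argmax(dim=-1, keepdim=True)
        y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
        return (y_hard - y_soft).detach() + y_soft
    return y_soft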
Example #3
def test():
    t = time.time()
    nll_test = []
    nll_var_test = []
    mse_1_test = []
    mse_10_test = []
    mse_20_test = []
    mse_static = []

    nll_M_test = []
    nll_M_var_test = []

    decoder.eval()
    if not args.cuda:
        decoder.load_state_dict(torch.load(decoder_file, map_location='cpu'))
    else:
        decoder.load_state_dict(torch.load(decoder_file))

    for batch_idx, (data, relations) in enumerate(test_loader):
        with torch.no_grad():

            if args.full_graph:
                zeros = torch.zeros([data.size(0), rel_rec.size(0)])
                ones = torch.ones([data.size(0), rel_rec.size(0)])
                if args.NRI:
                    stack = [ones] + [zeros for _ in range(edge_types - 1)]
                    rel_type_onehot = torch.stack(stack, -1)
                elif args.sigmoid:
                    stack = [ones for _ in range(args.num_factors)]
                    rel_type_onehot = torch.stack(stack, -1)
                else:
                    stack = []
                    for i in range(len(args.edge_types_list)):
                        stack += [ones] + [
                            zeros for _ in range(args.edge_types_list[i] - 1)
                        ]
                    rel_type_onehot = torch.stack(stack, -1)

            else:
                if args.NRI:
                    rel_type_onehot = torch.FloatTensor(
                        data.size(0), rel_rec.size(0), edge_types)
                    rel_type_onehot.zero_()
                    rel_type_onehot.scatter_(
                        2, relations.view(data.size(0), -1, 1), 1)
                elif args.sigmoid:
                    rel_type_onehot = relations.transpose(1, 2).type(
                        torch.FloatTensor)
                else:
                    rel_type_onehot = [
                        torch.FloatTensor(data.size(0), rel_rec.size(0), types)
                        for types in args.edge_types_list
                    ]
                    rel_type_onehot = [rel.zero_() for rel in rel_type_onehot]
                    rel_type_onehot = [
                        rel_type_onehot[i].scatter_(
                            2, relations[:, i, :].view(data.size(0), -1, 1), 1)
                        for i in range(len(rel_type_onehot))
                    ]
                    rel_type_onehot = torch.cat(rel_type_onehot, dim=-1)

            data_decoder = data[:, :, -args.timesteps:, :]

            if args.cuda:
                data_decoder, rel_type_onehot = data_decoder.cuda(
                ), rel_type_onehot.cuda()
            data_decoder = data_decoder.contiguous()

            data_decoder, rel_type_onehot = Variable(data_decoder), Variable(
                rel_type_onehot)

            target = data_decoder[:, :,
                                  1:, :]  # dimensions are [batch, particle, time, state]
            output = decoder(data_decoder, rel_type_onehot, rel_rec, rel_send,
                             1)

            if args.plot:
                import matplotlib.pyplot as plt
                output_plot = decoder(data_decoder, rel_type_onehot, rel_rec,
                                      rel_send, 49)
                from trajectory_plot import draw_lines
                for i in range(args.batch_size):
                    fig = plt.figure(figsize=(7, 7))
                    ax = fig.add_axes([0, 0, 1, 1])
                    xmin_t, ymin_t, xmax_t, ymax_t = draw_lines(target,
                                                                i,
                                                                linestyle=':',
                                                                alpha=0.6)
                    xmin_o, ymin_o, xmax_o, ymax_o = draw_lines(
                        output_plot.detach().numpy(), i, linestyle='-')

                    ax.set_xlim([min(xmin_t, xmin_o), max(xmax_t, xmax_o)])
                    ax.set_ylim([min(ymin_t, ymin_o), max(ymax_t, ymax_o)])
                    ax.set_xticks([])
                    ax.set_yticks([])
                    #plt.savefig(os.path.join(args.load_folder,str(i)+'_pred_and_true_.png'), dpi=300)
                    plt.show()

            loss_nll = nll_gaussian(output, target, args.var)
            loss_nll_var = nll_gaussian_var(output, target, args.var)

            output_M = decoder(data_decoder, rel_type_onehot, rel_rec,
                               rel_send, args.prediction_steps)
            loss_nll_M = nll_gaussian(output_M, target, args.var)

            output_10 = decoder(data_decoder, rel_type_onehot, rel_rec,
                                rel_send, 10)
            output_20 = decoder(data_decoder, rel_type_onehot, rel_rec,
                                rel_send, 20)
            mse_1_test.append(F.mse_loss(output, target).data.item())
            mse_10_test.append(F.mse_loss(output_10, target).data.item())
            mse_20_test.append(F.mse_loss(output_20, target).data.item())

            static = F.mse_loss(data_decoder[:, :, :-1, :],
                                data_decoder[:, :, 1:, :])
            mse_static.append(static.data.item())

            nll_test.append(loss_nll.data.item())
            nll_var_test.append(loss_nll_var.data.item())
            nll_M_test.append(loss_nll_M.data.item())

    print('--------------------------------')
    print('------------Testing-------------')
    print('--------------------------------')
    print('nll_test: {:.2f}'.format(np.mean(nll_test)),
          'nll_M_test: {:.2f}'.format(np.mean(nll_M_test)),
          'mse_1_test: {:.10f}'.format(np.mean(mse_1_test)),
          'mse_10_test: {:.10f}'.format(np.mean(mse_10_test)),
          'mse_20_test: {:.10f}'.format(np.mean(mse_20_test)),
          'mse_static: {:.10f}'.format(np.mean(mse_static)),
          'time: {:.1f}s'.format(time.time() - t))
    print('--------------------------------', file=log)
    print('------------Testing-------------', file=log)
    print('--------------------------------', file=log)
    print('nll_test: {:.2f}'.format(np.mean(nll_test)),
          'nll_M_test: {:.2f}'.format(np.mean(nll_M_test)),
          'mse_1_test: {:.10f}'.format(np.mean(mse_1_test)),
          'mse_10_test: {:.10f}'.format(np.mean(mse_10_test)),
          'mse_20_test: {:.10f}'.format(np.mean(mse_20_test)),
          'mse_static: {:.10f}'.format(np.mean(mse_static)),
          'time: {:.1f}s'.format(time.time() - t),
          file=log)
    log.flush()
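
Note: rel_rec and rel_send are module-level matrices mapping each directed
edge of the fully connected particle graph (self-loops excluded) to one-hot
receiver and sender rows. A sketch of the usual NRI-style construction
(assumed; it is not shown in these snippets):

import numpy as np
import torch

def build_rel_matrices(num_atoms):
    # One row per directed edge, one column per node; np.where walks the
    # off-diagonal entries in row-major order. The receiver/sender index
    # convention here follows the original NRI code.
    off_diag = np.ones([num_atoms, num_atoms]) - np.eye(num_atoms)
    rel_rec = torch.FloatTensor(np.eye(num_atoms)[np.where(off_diag)[1]])
    rel_send = torch.FloatTensor(np.eye(num_atoms)[np.where(off_diag)[0]])
    return rel_rec, rel_send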
Example #4
def train(epoch, best_val_loss):
    t = time.time()
    nll_train = []
    nll_var_train = []
    mse_train = []

    decoder.train()
    scheduler.step()
    if not args.plot:
        for batch_idx, (data, relations) in enumerate(
                train_loader
        ):  # relations are the ground truth interactions graphs

            optimizer.zero_grad()

            if args.full_graph:
                zeros = torch.zeros([data.size(0), rel_rec.size(0)])
                ones = torch.ones([data.size(0), rel_rec.size(0)])
                if args.NRI:
                    stack = [ones] + [zeros for _ in range(edge_types - 1)]
                    rel_type_onehot = torch.stack(stack, -1)
                elif args.sigmoid:
                    stack = [ones for _ in range(args.num_factors)]
                    rel_type_onehot = torch.stack(stack, -1)
                else:
                    stack = []
                    for i in range(len(args.edge_types_list)):
                        stack += [ones] + [
                            zeros for _ in range(args.edge_types_list[i] - 1)
                        ]
                    rel_type_onehot = torch.stack(stack, -1)

            else:
                if args.NRI:
                    rel_type_onehot = torch.FloatTensor(
                        data.size(0), rel_rec.size(0), edge_types)
                    rel_type_onehot.zero_()
                    rel_type_onehot.scatter_(
                        2, relations.view(data.size(0), -1, 1), 1)
                elif args.sigmoid:
                    rel_type_onehot = relations.transpose(1, 2).type(
                        torch.FloatTensor)
                else:
                    rel_type_onehot = [
                        torch.FloatTensor(data.size(0), rel_rec.size(0), types)
                        for types in args.edge_types_list
                    ]
                    rel_type_onehot = [rel.zero_() for rel in rel_type_onehot]
                    rel_type_onehot = [
                        rel_type_onehot[i].scatter_(
                            2, relations[:, i, :].view(data.size(0), -1, 1), 1)
                        for i in range(len(rel_type_onehot))
                    ]
                    rel_type_onehot = torch.cat(rel_type_onehot, dim=-1)

            if args.dont_split_data:
                data_decoder = data[:, :, :args.timesteps, :]
            elif args.split_enc_only:
                data_decoder = data
            else:
                assert (data.size(2) - args.timesteps) >= args.timesteps
                data_decoder = data[:, :, -args.timesteps:, :]

            if args.cuda:
                data_decoder, rel_type_onehot = data_decoder.cuda(
                ), rel_type_onehot.cuda()
            data_decoder = data_decoder.contiguous()

            data_decoder, rel_type_onehot = Variable(data_decoder), Variable(
                rel_type_onehot)

            target = data_decoder[:, :,
                                  1:, :]  # dimensions are [batch, particle, time, state]
            output = decoder(data_decoder, rel_type_onehot, rel_rec, rel_send,
                             args.prediction_steps)

            loss_nll = nll_gaussian(output, target, args.var)
            loss_nll_var = nll_gaussian_var(output, target, args.var)

            loss_nll.backward()
            optimizer.step()

            mse_train.append(F.mse_loss(output, target).data.item())
            nll_train.append(loss_nll.data.item())
            nll_var_train.append(loss_nll_var.data.item())

    nll_val = []
    nll_var_val = []
    mse_val = []

    nll_M_val = []
    nll_M_var_val = []

    decoder.eval()
    for batch_idx, (data, relations) in enumerate(valid_loader):
        with torch.no_grad():

            if args.full_graph:
                zeros = torch.zeros([data.size(0), rel_rec.size(0)])
                ones = torch.ones([data.size(0), rel_rec.size(0)])
                if args.NRI:
                    stack = [ones] + [zeros for _ in range(edge_types - 1)]
                    rel_type_onehot = torch.stack(stack, -1)
                elif args.sigmoid:
                    stack = [ones for _ in range(args.num_factors)]
                    rel_type_onehot = torch.stack(stack, -1)
                else:
                    stack = []
                    for i in range(len(args.edge_types_list)):
                        stack += [ones] + [
                            zeros for _ in range(args.edge_types_list[i] - 1)
                        ]
                    rel_type_onehot = torch.stack(stack, -1)

            else:
                if args.NRI:
                    rel_type_onehot = torch.FloatTensor(
                        data.size(0), rel_rec.size(0), edge_types)
                    rel_type_onehot.zero_()
                    rel_type_onehot.scatter_(
                        2, relations.view(data.size(0), -1, 1), 1)
                elif args.sigmoid:
                    rel_type_onehot = relations.transpose(1, 2).type(
                        torch.FloatTensor)
                else:
                    rel_type_onehot = [
                        torch.FloatTensor(data.size(0), rel_rec.size(0), types)
                        for types in args.edge_types_list
                    ]
                    rel_type_onehot = [rel.zero_() for rel in rel_type_onehot]
                    rel_type_onehot = [
                        rel_type_onehot[i].scatter_(
                            2, relations[:, i, :].view(data.size(0), -1, 1), 1)
                        for i in range(len(rel_type_onehot))
                    ]
                    rel_type_onehot = torch.cat(rel_type_onehot, dim=-1)

            if args.dont_split_data:
                data_decoder = data[:, :, :args.timesteps, :]
            elif args.split_enc_only:
                data_decoder = data
            else:
                assert (data.size(2) - args.timesteps) >= args.timesteps
                data_decoder = data[:, :, -args.timesteps:, :]

            if args.cuda:
                data_decoder, rel_type_onehot = data_decoder.cuda(
                ), rel_type_onehot.cuda()
            data_decoder = data_decoder.contiguous()

            data_decoder, rel_type_onehot = Variable(data_decoder), Variable(
                rel_type_onehot)

            target = data_decoder[:, :,
                                  1:, :]  # dimensions are [batch, particle, time, state]
            output = decoder(data_decoder, rel_type_onehot, rel_rec, rel_send,
                             1)

            if args.plot:
                import matplotlib.pyplot as plt
                output_plot = decoder(data_decoder, rel_type_onehot, rel_rec,
                                      rel_send, 49)

                from trajectory_plot import draw_lines
                for i in range(args.batch_size):
                    fig = plt.figure(figsize=(7, 7))
                    ax = fig.add_axes([0, 0, 1, 1])
                    xmin_t, ymin_t, xmax_t, ymax_t = draw_lines(target,
                                                                i,
                                                                linestyle=':',
                                                                alpha=0.6)
                    xmin_o, ymin_o, xmax_o, ymax_o = draw_lines(
                        output_plot.detach().numpy(), i, linestyle='-')

                    ax.set_xlim([min(xmin_t, xmin_o), max(xmax_t, xmax_o)])
                    ax.set_ylim([min(ymin_t, ymin_o), max(ymax_t, ymax_o)])
                    ax.set_xticks([])
                    ax.set_yticks([])
                    plt.show()

            loss_nll = nll_gaussian(output, target, args.var)
            loss_nll_var = nll_gaussian_var(output, target, args.var)

            output_M = decoder(data_decoder, rel_type_onehot, rel_rec,
                               rel_send, args.prediction_steps)
            loss_nll_M = nll_gaussian(output_M, target, args.var)
            loss_nll_M_var = nll_gaussian_var(output_M, target, args.var)

            mse_val.append(F.mse_loss(output_M, target).data.item())
            nll_val.append(loss_nll.data.item())
            nll_var_val.append(loss_nll_var.data.item())

            nll_M_val.append(loss_nll_M.data.item())
            nll_M_var_val.append(loss_nll_M_var.data.item())

    print('Epoch: {:03d}'.format(epoch),
          'time: {:.1f}s'.format(time.time() - t),
          'nll_trn: {:.2f}'.format(np.mean(nll_train)),
          'mse_trn: {:.10f}'.format(np.mean(mse_train)),
          'nll_val: {:.2f}'.format(np.mean(nll_M_val)),
          'mse_val: {:.10f}'.format(np.mean(mse_val)))

    print('Epoch: {:03d}'.format(epoch),
          'time: {:.1f}s'.format(time.time() - t),
          'nll_trn: {:.2f}'.format(np.mean(nll_train)),
          'mse_trn: {:.10f}'.format(np.mean(mse_train)),
          'nll_val: {:.2f}'.format(np.mean(nll_M_val)),
          'mse_val: {:.10f}'.format(np.mean(mse_val)),
          file=log)

    if epoch == 0:
        labels = ['epoch', 'nll trn', 'mse train', 'nll var trn']
        labels += [
            'nll val', 'nll M val', 'mse val', 'nll var val', 'nll M var val'
        ]
        csv_writer.writerow(labels)

    csv_writer.writerow(
        [
            epoch,
            np.mean(nll_train),
            np.mean(mse_train),
            np.mean(nll_var_train)
        ] + [np.mean(nll_val),
             np.mean(nll_M_val),
             np.mean(mse_val)] +
        [np.mean(nll_var_val), np.mean(nll_M_var_val)])

    log.flush()
    if args.save_folder and np.mean(nll_M_val) < best_val_loss:
        torch.save(decoder.state_dict(), decoder_file)
        print('Best model so far, saving...')
    return np.mean(nll_M_val)
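
Note: both train functions return the validation NLL, so the surrounding
script presumably drives them with an epoch loop of roughly this shape (a
sketch of the usual NRI-style main block; args.epochs and the log handle
are assumptions):

best_val_loss = np.inf
best_epoch = 0
for epoch in range(args.epochs):
    val_loss = train(epoch, best_val_loss)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch
print('Optimization finished! Best epoch: {:04d}'.format(best_epoch))
test()
log.close()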
Example #5
def test():
    nll_test = []
    nll_var_test = []

    acc_test = []
    acc_blocks_test = []
    acc_var_test = []
    acc_var_blocks_test = []
    perm_test = []

    mse_1_test = []
    mse_10_test = []
    mse_20_test = []

    nll_M_test = []
    nll_M_var_test = []

    encoder.eval()
    decoder.eval()
    if not args.cuda:
        encoder.load_state_dict(torch.load(encoder_file, map_location='cpu'))
        decoder.load_state_dict(torch.load(decoder_file, map_location='cpu'))
    else:
        encoder.load_state_dict(torch.load(encoder_file))
        decoder.load_state_dict(torch.load(decoder_file))

    for batch_idx, (data, relations) in enumerate(test_loader):
        with torch.no_grad():
            if args.cuda:
                data, relations = data.cuda(), relations.cuda()

            assert (data.size(2) - args.timesteps) >= args.timesteps
            data_encoder = data[:, :, :args.timesteps, :].contiguous()
            data_decoder = data[:, :, -args.timesteps:, :].contiguous()

            # dim of logits, edges and prob are [batchsize, N^2-N, sum(edge_types_list)] where N = no. of particles
            logits = encoder(data_encoder, rel_rec, rel_send)
            edges = my_sigmoid(logits,
                               hard=args.hard,
                               sharpness=args.sigmoid_sharpness)

            acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = edge_accuracy_perm_sigmoid(
                edges, relations)

            target = data_decoder[:, :,
                                  1:, :]  # dimensions are [batch, particle, time, state]
            output = decoder(data_decoder, edges, rel_rec, rel_send, 1)

            if args.plot:
                import matplotlib.pyplot as plt
                output_plot = decoder(data_decoder, edges, rel_rec, rel_send,
                                      49)

                output_plot_en = decoder(data_encoder, edges, rel_rec,
                                         rel_send, 49)
                from trajectory_plot import draw_lines

                acc_batch, perm, acc_blocks_batch = edge_accuracy_perm_sigmoid_batch(
                    edges, relations)

                for i in range(args.batch_size):
                    fig = plt.figure(figsize=(7, 7))
                    ax = fig.add_axes([0, 0, 1, 1])
                    xmin_t, ymin_t, xmax_t, ymax_t = draw_lines(target,
                                                                i,
                                                                linestyle=':',
                                                                alpha=0.6)
                    xmin_o, ymin_o, xmax_o, ymax_o = draw_lines(
                        output_plot.detach().numpy(), i, linestyle='-')

                    ax.set_xlim([min(xmin_t, xmin_o), max(xmax_t, xmax_o)])
                    ax.set_ylim([min(ymin_t, ymin_o), max(ymax_t, ymax_o)])
                    ax.set_xticks([])
                    ax.set_yticks([])
                    block_names = [str(j) for j in range(args.num_factors)]
                    acc_text = [
                        'layer ' + block_names[j] +
                        ' acc: {:02.0f}%'.format(100 * acc_blocks_batch[i, j])
                        for j in range(acc_blocks_batch.shape[1])
                    ]
                    acc_text = ', '.join(acc_text)
                    plt.text(0.5,
                             0.95,
                             acc_text,
                             horizontalalignment='center',
                             transform=ax.transAxes)
                    #plt.savefig(os.path.join(args.load_folder,str(i)+'_pred_and_true_.png'), dpi=300)
                    plt.show()

            loss_nll = nll_gaussian(output, target, args.var)
            loss_nll_var = nll_gaussian_var(output, target, args.var)

            output_10 = decoder(data_decoder, edges, rel_rec, rel_send, 10)
            output_20 = decoder(data_decoder, edges, rel_rec, rel_send, 20)
            mse_1_test.append(F.mse_loss(output, target).data.item())
            mse_10_test.append(F.mse_loss(output_10, target).data.item())
            mse_20_test.append(F.mse_loss(output_20, target).data.item())

            loss_nll_M = nll_gaussian(output_10, target, args.var)
            loss_nll_M_var = nll_gaussian_var(output_10, target, args.var)

            perm_test.append(perm)
            acc_test.append(acc_perm)
            acc_blocks_test.append(acc_blocks)
            acc_var_test.append(acc_var)
            acc_var_blocks_test.append(acc_var_blocks)

            nll_test.append(loss_nll.data.item())
            nll_var_test.append(loss_nll_var.data.item())
            nll_M_test.append(loss_nll_M.data.item())
            nll_M_var_test.append(loss_nll_M_var.data.item())

    print('--------------------------------')
    print('------------Testing-------------')
    print('--------------------------------')
    print(
        'nll_test: {:.2f}'.format(np.mean(nll_test)),
        'nll_M_test: {:.2f}'.format(np.mean(nll_M_test)),
        'mse_1_test: {:.10f}'.format(np.mean(mse_1_test)),
        'mse_10_test: {:.10f}'.format(np.mean(mse_10_test)),
        'mse_20_test: {:.10f}'.format(np.mean(mse_20_test)),
        'acc_test: {:.5f}'.format(np.mean(acc_test)),
        'acc_var_test: {:.5f}'.format(np.mean(acc_var_test)), 'acc_b_test: ' +
        str(np.around(np.mean(np.array(acc_blocks_test), axis=0), 4)),
        'acc_var_b_test: ' +
        str(np.around(np.mean(np.array(acc_var_blocks_test), axis=0), 4)))
    print('--------------------------------', file=log)
    print('------------Testing-------------', file=log)
    print('--------------------------------', file=log)
    print('nll_test: {:.2f}'.format(np.mean(nll_test)),
          'nll_M_test: {:.2f}'.format(np.mean(nll_M_test)),
          'mse_1_test: {:.10f}'.format(np.mean(mse_1_test)),
          'mse_10_test: {:.10f}'.format(np.mean(mse_10_test)),
          'mse_20_test: {:.10f}'.format(np.mean(mse_20_test)),
          'acc_test: {:.5f}'.format(np.mean(acc_test)),
          'acc_var_test: {:.5f}'.format(np.mean(acc_var_test)),
          'acc_b_test: ' +
          str(np.around(np.mean(np.array(acc_blocks_test), axis=0), 4)),
          'acc_var_b_test: ' +
          str(np.around(np.mean(np.array(acc_var_blocks_test), axis=0), 4)),
          file=log)
    log.flush()
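
Note: my_sigmoid is the sigmoid-variant counterpart of gumbel_softmax,
treating each factor as an independent binary edge. A plausible sketch
(assumed; sharpness scales the logits and hard thresholds at 0.5 with a
straight-through gradient):

import torch

def my_sigmoid(logits, hard=True, sharpness=1.0):
    # Sharpened element-wise sigmoid; with hard=True, round to {0, 1}
    # while keeping the soft value's gradient (straight-through).
    y_soft = torch.sigmoid(sharpness * logits)
    if hard:
        y_hard = (y_soft > 0.5).float()
        return (y_hard - y_soft).detach() + y_soft
    return y_soft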
Example #6
def train(epoch, best_val_loss):
    t = time.time()
    nll_train = []
    nll_var_train = []
    mse_train = []

    kl_train = []
    kl_list_train = []
    kl_var_list_train = []

    acc_train = []
    perm_train = []
    acc_blocks_train = []
    acc_var_train = []
    acc_var_blocks_train = []

    KLb_train = []
    KLb_blocks_train = []

    encoder.train()
    decoder.train()
    scheduler.step()
    if not args.plot:
        for batch_idx, (data, relations) in enumerate(
                train_loader
        ):  # relations are the ground truth interactions graphs
            if args.cuda:
                data, relations = data.cuda(), relations.cuda()
            data, relations = Variable(data), Variable(relations)

            if args.dont_split_data:
                data_encoder = data[:, :, :args.timesteps, :].contiguous()
                data_decoder = data[:, :, :args.timesteps, :].contiguous()
            elif args.split_enc_only:
                data_encoder = data[:, :, :args.timesteps, :].contiguous()
                data_decoder = data
            else:
                assert (data.size(2) - args.timesteps) >= args.timesteps
                data_encoder = data[:, :, :args.timesteps, :].contiguous()
                data_decoder = data[:, :, -args.timesteps:, :].contiguous()

            optimizer.zero_grad()

            logits = encoder(data_encoder, rel_rec, rel_send)

            # dim of logits, edges and prob are [batchsize, N^2-N, edgetypes] where N = no. of particles

            edges = my_sigmoid(logits,
                               hard=args.hard,
                               sharpness=args.sigmoid_sharpness)

            loss_kl = 0
            loss_kl_split = [0]
            loss_kl_var_split = [0]

            KLb_train.append(0)
            KLb_blocks_train.append([0])

            acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = edge_accuracy_perm_sigmoid(
                edges, relations)

            target = data_decoder[:, :,
                                  1:, :]  # dimensions are [batch, particle, time, state]
            output = decoder(data_decoder, edges, rel_rec, rel_send,
                             args.prediction_steps)

            loss_nll = nll_gaussian(output, target, args.var)
            loss_nll_var = nll_gaussian_var(output, target, args.var)

            loss = F.mse_loss(output, target)

            perm_train.append(perm)
            acc_train.append(acc_perm)
            acc_blocks_train.append(acc_blocks)
            acc_var_train.append(acc_var)
            acc_var_blocks_train.append(acc_var_blocks)

            loss.backward()
            optimizer.step()

            mse_train.append(loss.data.item())
            nll_train.append(loss_nll.data.item())
            nll_var_train.append(loss_nll_var.data.item())

    nll_val = []
    nll_var_val = []
    mse_val = []

    kl_val = []
    kl_list_val = []
    kl_var_list_val = []

    acc_val = []
    acc_blocks_val = []
    acc_var_val = []
    acc_var_blocks_val = []
    perm_val = []

    KLb_val = []
    KLb_blocks_val = []  # KL between blocks list

    nll_M_val = []
    nll_M_var_val = []

    encoder.eval()
    decoder.eval()
    for batch_idx, (data, relations) in enumerate(valid_loader):
        with torch.no_grad():
            if args.cuda:
                data, relations = data.cuda(), relations.cuda()

            if args.dont_split_data:
                data_encoder = data[:, :, :args.timesteps, :].contiguous()
                data_decoder = data[:, :, :args.timesteps, :].contiguous()
            elif args.split_enc_only:
                data_encoder = data[:, :, :args.timesteps, :].contiguous()
                data_decoder = data
            else:
                assert (data.size(2) - args.timesteps) >= args.timesteps
                data_encoder = data[:, :, :args.timesteps, :].contiguous()
                data_decoder = data[:, :, -args.timesteps:, :].contiguous()

            # dim of logits, edges are [batchsize, N^2-N, sum(edge_types_list)] where N = no. of particles
            logits = encoder(data_encoder, rel_rec, rel_send)

            edges = my_sigmoid(logits,
                               hard=args.hard,
                               sharpness=args.sigmoid_sharpness)

            loss_kl = 0
            loss_kl_split = [0]
            loss_kl_var_split = [0]

            KLb_val.append(0)
            KLb_blocks_val.append([0])

            acc_perm, perm, acc_blocks, acc_var, acc_var_blocks = edge_accuracy_perm_sigmoid(
                edges, relations)

            target = data_decoder[:, :,
                                  1:, :]  # dimensions are [batch, particle, time, state]
            output = decoder(data_decoder, edges, rel_rec, rel_send, 1)
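            # nll_val below uses 1-step-ahead predictions; the multi-step
            # (args.prediction_steps) rollout is scored separately as nll_M_val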

            if args.plot:
                import matplotlib.pyplot as plt
                output_plot = decoder(data_decoder, edges, rel_rec, rel_send,
                                      49)
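                # assumption: 49 = args.timesteps - 1, i.e. a rollout over the
                # full plotted trajectory window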

                acc_batch, perm, acc_blocks_batch = edge_accuracy_perm_sigmoid_batch(
                    edges, relations)

                from trajectory_plot import draw_lines
                for i in range(args.batch_size):
                    fig = plt.figure(figsize=(7, 7))
                    ax = fig.add_axes([0, 0, 1, 1])
                    # draw_lines presumably expects numpy arrays; move tensors
                    # to CPU before converting so this also works under CUDA
                    xmin_t, ymin_t, xmax_t, ymax_t = draw_lines(
                        target.detach().cpu().numpy(), i, linestyle=':',
                        alpha=0.6)
                    xmin_o, ymin_o, xmax_o, ymax_o = draw_lines(
                        output_plot.detach().cpu().numpy(), i, linestyle='-')

                    ax.set_xlim([min(xmin_t, xmin_o), max(xmax_t, xmax_o)])
                    ax.set_ylim([min(ymin_t, ymin_o), max(ymax_t, ymax_o)])
                    ax.set_xticks([])
                    ax.set_yticks([])
                    block_names = [str(j) for j in range(args.num_factors)]
                    acc_text = [
                        'layer ' + block_names[j] +
                        ' acc: {:02.0f}%'.format(100 * acc_blocks_batch[i, j])
                        for j in range(acc_blocks_batch.shape[1])
                    ]
                    acc_text = ', '.join(acc_text)
                    plt.text(0.5,
                             0.95,
                             acc_text,
                             horizontalalignment='center',
                             transform=ax.transAxes)
                    plt.show()

            loss_nll = nll_gaussian(output, target, args.var)
            loss_nll_var = nll_gaussian_var(output, target, args.var)

            output_M = decoder(data_decoder, edges, rel_rec, rel_send,
                               args.prediction_steps)
            loss_nll_M = nll_gaussian(output_M, target, args.var)
            loss_nll_M_var = nll_gaussian_var(output_M, target, args.var)

            perm_val.append(perm)
            acc_val.append(acc_perm)
            acc_blocks_val.append(acc_blocks)
            acc_var_val.append(acc_var)
            acc_var_blocks_val.append(acc_var_blocks)

            mse_val.append(F.mse_loss(output_M, target).data.item())
            nll_val.append(loss_nll.data.item())
            nll_var_val.append(loss_nll_var.data.item())

            nll_M_val.append(loss_nll_M.data.item())
            nll_M_var_val.append(loss_nll_M_var.data.item())

    print(
        'Epoch: {:03d}'.format(epoch),
        'perm_val: ' + str(np.around(np.mean(np.array(perm_val), axis=0), 4)),
        'time: {:.1f}s'.format(time.time() - t))
    print(
        'nll_trn: {:.2f}'.format(np.mean(nll_train)),
        'mse_trn: {:.10f}'.format(np.mean(mse_train)),
        'acc_trn: {:.5f}'.format(np.mean(acc_train)), 'acc_b_trn: ' +
        str(np.around(np.mean(np.array(acc_blocks_train), axis=0), 4)))
    print(
        'nll_M_val: {:.2f}'.format(np.mean(nll_M_val)),
        'mse_val: {:.10f}'.format(np.mean(mse_val)),
        'acc_val: {:.5f}'.format(np.mean(acc_val)), 'acc_b_val: ' +
        str(np.around(np.mean(np.array(acc_blocks_val), axis=0), 4)))
    print('Epoch: {:03d}'.format(epoch),
          'perm_val: ' +
          str(np.around(np.mean(np.array(perm_val), axis=0), 4)),
          'time: {:.1f}s'.format(time.time() - t),
          file=log)
    print('nll_trn: {:.2f}'.format(np.mean(nll_train)),
          'mse_trn: {:.10f}'.format(np.mean(mse_train)),
          'acc_trn: {:.5f}'.format(np.mean(acc_train)),
          'acc_b_trn: ' +
          str(np.around(np.mean(np.array(acc_blocks_train), axis=0), 4)),
          file=log)
    print('nll_val: {:.2f}'.format(np.mean(nll_val)),
          'nll_M_val: {:.2f}'.format(np.mean(nll_M_val)),
          'mse_val: {:.10f}'.format(np.mean(mse_val)),
          'acc_val: {:.5f}'.format(np.mean(acc_val)),
          'acc_b_val: ' +
          str(np.around(np.mean(np.array(acc_blocks_val), axis=0), 4)),
          file=log)
    if epoch == 0:
        labels = ['epoch', 'nll trn', 'mse train', 'nll var trn', 'acc trn']
        labels += ['b' + str(i) + ' acc trn' for i in range(args.num_factors)]
        labels += ['acc var trn'] + [
            'b' + str(i) + ' acc var trn' for i in range(args.num_factors)
        ]
        labels += ['nll val', 'nll M val', 'mse val', 'acc val']
        labels += ['b' + str(i) + ' acc val' for i in range(args.num_factors)]
        labels += ['nll var val', 'nll M var val']
        labels += ['acc var val'] + [
            'b' + str(i) + ' acc var val' for i in range(args.num_factors)
        ]
        csv_writer.writerow(labels)

        labels = ['trn ' + str(i) for i in range(len(perm_train[0]))]
        labels += ['val ' + str(i) for i in range(len(perm_val[0]))]
        perm_writer.writerow(labels)

    csv_writer.writerow(
        [
            epoch,
            np.mean(nll_train),
            np.mean(mse_train),
            np.mean(nll_var_train),
            np.mean(acc_train)
        ] + list(np.mean(np.array(acc_blocks_train), axis=0)) +
        [np.mean(acc_var_train)] +
        list(np.mean(np.array(acc_var_blocks_train), axis=0)) + [
            np.mean(nll_val),
            np.mean(nll_M_val),
            np.mean(mse_val),
            np.mean(acc_val)
        ] + list(np.mean(np.array(acc_blocks_val), axis=0)) +
        [np.mean(nll_var_val), np.mean(nll_M_var_val)] +
        [np.mean(acc_var_val)] +
        list(np.mean(np.array(acc_var_blocks_val), axis=0)))
    perm_writer.writerow(
        list(np.mean(np.array(perm_train), axis=0)) +
        list(np.mean(np.array(perm_val), axis=0)))

    log.flush()
    if args.save_folder and np.mean(nll_M_val) < best_val_loss:
        torch.save(encoder.state_dict(), encoder_file)
        torch.save(decoder.state_dict(), decoder_file)
        print('Best model so far, saving...')
    return np.mean(nll_M_val)
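
my_sigmoid is called throughout this example but not defined in it. Below is a
minimal sketch of one plausible implementation, assuming a sharpness-scaled
sigmoid with an optional straight-through hard threshold; the hard and
sharpness arguments come from the call sites above, while the body itself is
an assumption, not the original code.

import torch

def my_sigmoid(logits, hard=False, sharpness=1.0):
    # scale logits by an inverse-temperature-like sharpness, then squash
    probs = torch.sigmoid(sharpness * logits)
    if hard:
        # straight-through estimator: binary 0/1 edges in the forward pass,
        # sigmoid gradients in the backward pass
        hard_edges = (probs > 0.5).float()
        return hard_edges + probs - probs.detach()
    return probs

# usage matching the call sites above:
# edges = my_sigmoid(logits, hard=args.hard, sharpness=args.sigmoid_sharpness)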