def main():
    # Train model
    best_val_loss = np.inf
    best_epoch = 0
    trajectory_len = 19
    
    valid_loader, test_loader = load_AL_data(batch_size=args.batch_size,
                                             total_size=args.dataset_size,
                                             suffix=args.suffix)

    func = rollout_sliding_cube
    simulator = ControlSimulator(func, trajectory_len, args.input_atoms, args.target_atoms, \
                                 low=1, high=3, control_low=0, control_high=5)
    uncertain_sampler = MaximalEntropySimulatorSampler(simulator)
#     random_sampler = RandomSimulatorSampler(simulator)
    
    logger = Logger(save_folder)
    
    for epoch in range(args.epochs): 
        if epoch == 0: 
            data = []
            nodes = []
            for i in range(args.input_atoms):
                new_data, uncertain_nodes = uncertain_sampler.sample(i, args.batch_size)
                data.append(new_data)
                nodes.append(uncertain_nodes)
            
            data = torch.cat(data)
            nodes = torch.cat(nodes)
            
            train_dataset = ALDataset(data, nodes)
            train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False)
        else:
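            # The maximal-entropy criterion picks the input node whose edges in
            # the learned relation graph are currently the most uncertain
            # (highest entropy), so new queries intervene where the model knows
            # least (inferred from the sampler's name).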
            control_node = uncertain_sampler.criterion(decoder.rel_graph)
            new_data, uncertain_nodes = uncertain_sampler.sample(control_node, args.batch_size) 
            train_dataset, train_loader = update_ALDataset(train_dataset, new_data, uncertain_nodes, args.batch_size)   
#         print(train_dataset.data[:,:,:3,0], train_dataset.nodes)
        #TODO: when len(train_dataset) reaches budget, force stop
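        # A minimal sketch of that budget stop (assumes an `args.budget`
        # argument giving the maximum number of queried samples; the actual
        # flag name is an assumption):
        # if len(train_dataset) >= args.budget:
        #     print('Data budget reached, stopping active learning.')
        #     break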
        print('#batches in train_dataset', len(train_dataset)/args.batch_size)
        train_control(args, log_prior, logger, optimizer, save_folder, train_loader, epoch, decoder, \
                       rel_rec, rel_send, mask_grad=True)
        nll_val_loss = val_control(args, log_prior, logger, save_folder, valid_loader, epoch, decoder, rel_rec, rel_send)
        
        scheduler.step()
        if nll_val_loss < best_val_loss:
            best_val_loss = nll_val_loss
            best_epoch = epoch
 
        
    print("Optimization Finished!")
    print("Best Epoch: {:04d}".format(best_epoch))
    if args.save_folder:
        print("Best Epoch: {:04d}".format(best_epoch), file=log)
        log.flush()

    test_control(test_loader)
    if log is not None:
        print(save_folder)
        log.close()
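
# The main() above relies on ALDataset and update_ALDataset. A minimal sketch
# consistent with how they are called here (hypothetical: the project's real
# implementations likely do more, e.g. normalization and grouping). Assumes
# torch and torch.utils.data.DataLoader are imported.
class ALDatasetSketch(torch.utils.data.Dataset):
    """Stores queried trajectories plus the node each query intervened on."""

    def __init__(self, data, nodes):
        self.data = data      # (num_samples, num_atoms, timesteps, feat_dims)
        self.nodes = nodes    # (num_samples,) intervened-node index per sample

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        return self.data[idx], self.nodes[idx]


def update_ALDataset_sketch(dataset, new_data, new_nodes, batch_size):
    """Appends newly queried samples and rebuilds an unshuffled DataLoader."""
    dataset.data = torch.cat([dataset.data, new_data])
    dataset.nodes = torch.cat([dataset.nodes, new_nodes])
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return dataset, loader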
Example #2
    def train_causal(self):
        """[summary]

            Args:
                action ([list]): the new trajectory setting.

            Returns:
                state: [learning_assess, obj_data_features]
                reward: - val_MSE
                done: Whether the data budget is met. If yes, the training can end early.
            """
        # GPUtil.showUtilization()
        # self.train_dataset.data size (batch_size, num_nodes, timesteps, feat_dims)
#         print('dataset size', len(self.train_dataset),
#               'last ten', self.train_dataset.data[-10:, :, 0, 0])
        train_data_loader = DataLoader(
            self.train_dataset, batch_size=self.args.train_bs, shuffle=False)
        for i in range(self.args.max_causal_epochs):
            print(str(i), "iter of epoch", self.epoch)
            nll, nll_lasttwo, kl, mse, control_constraint_loss, lr, rel_graphs, rel_graphs_grad, a, b, c, d, e = train_control(
                self.args, self.log_prior, self.causal_model_optimizer, self.save_folder, train_data_loader, self.causal_model, self.epoch)

            # val_dataset should be continuous, more coverage
            nll_val, nll_lasttwo_val, kl_val, mse_val, a_val, b_val, c_val, control_constraint_loss_val, nll_lasttwo_5_val, nll_lasttwo_10_val, nll_lasttwo__1_val, nll_lasttwo_1_val = val_control(
                self.args, self.log_prior, self.logger, self.save_folder, self.valid_data_loader, self.epoch, self.causal_model)
            # val_loss = 0
            self.scheduler.step()
            self.early_stop_monitor(nll_val)
            if self.early_stop_monitor.counter == self.args.patience:
                self.early_stop_monitor.counter = 0
                self.early_stop_monitor.best_score = None
                self.early_stop_monitor.stopped_epoch = i
                print("Early stopping", str(i), "iter of epoch", self.epoch)
                break

        self.logger.log('val', self.causal_model, self.epoch, nll_val, nll_lasttwo_val, kl_val=kl_val, mse_val=mse_val, a_val=a_val, b_val=b_val, c_val=c_val, control_constraint_loss_val=control_constraint_loss_val,
                        nll_lasttwo_5_val=nll_lasttwo_5_val,  nll_lasttwo_10_val=nll_lasttwo_10_val, nll_lasttwo__1_val=nll_lasttwo__1_val, nll_lasttwo_1_val=nll_lasttwo_1_val, scheduler=self.scheduler)

        if self.epoch % self.args.train_log_freq == 0:
            self.logger.log('train', self.causal_model, self.epoch, nll, nll_lasttwo, kl_train=kl, mse_train=mse, control_constraint_loss_train=control_constraint_loss, lr_train=lr, rel_graphs=rel_graphs,
                            rel_graphs_grad=rel_graphs_grad, msg_hook_weights_train=a, nll_lasttwo_5_train=b, nll_lasttwo_10_train=c, nll_lasttwo__1_train=d, nll_lasttwo_1_train=e)

        self.epoch += 1
        return nll_val
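
# train_causal above drives an EarlyStopping monitor through __call__ and its
# .counter / .best_score / .stopped_epoch attributes. A minimal sketch that is
# consistent with that usage (hypothetical; the project's own class may differ):
class EarlyStoppingSketch:
    def __init__(self, min_delta=0.0):
        self.best_score = None      # best (lowest) validation loss seen so far
        self.counter = 0            # iterations without sufficient improvement
        self.stopped_epoch = 0
        self.min_delta = min_delta

    def __call__(self, val_loss):
        if self.best_score is None or val_loss < self.best_score - self.min_delta:
            self.best_score = val_loss
            self.counter = 0
        else:
            self.counter += 1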
Example #3
def main():
    # Train model
    best_val_loss = np.inf
    best_epoch = 0

    # Active learning with grouping must use a controlled (grouped) training dataset
    train_dataset = OneGraphDataset.load_one_graph_data(
        'train_causal_vel_' + args.suffix,
        train_data_min_max=None,
        size=args.train_size,
        self_loop=args.self_loop,
        control=args.grouped,
        control_nodes=args.input_atoms,
        variations=args.variations,
        need_grouping=args.need_grouping)

    if args.val_grouped:
        # To see control loss, val and test should be grouped
        valid_dataset = OneGraphDataset.load_one_graph_data(
            'valid_causal_vel_' + args.val_suffix,
            train_data_min_max=[train_dataset.mins, train_dataset.maxs],
            size=args.val_size,
            self_loop=args.self_loop,
            control=True,
            control_nodes=args.input_atoms,
            variations=args.val_variations,
            need_grouping=args.val_need_grouping)
        valid_sampler = RandomPytorchSampler(valid_dataset)
        valid_data_loader = DataLoader(valid_dataset,
                                       batch_size=args.val_bs,
                                       shuffle=False,
                                       sampler=valid_sampler)

        # test_data = load_one_graph_data(
        #     'test_'+args.val_suffix, size=args.test_size, self_loop=args.self_loop, control=True, control_nodes=args.input_atoms, variations=4)
        # test_sampler = RandomPytorchSampler(test_data)
        # test_data_loader = DataLoader(
        #     test_data, batch_size=args.val_bs, shuffle=False, sampler=test_sampler)
    else:
        valid_dataset = OneGraphDataset.load_one_graph_data(
            'valid_causal_vel_' + args.val_suffix,
            train_data_min_max=[train_dataset.mins, train_dataset.maxs],
            size=args.val_size,
            self_loop=args.self_loop,
            control=False)
        valid_data_loader = DataLoader(valid_dataset,
                                       batch_size=args.val_bs,
                                       shuffle=True)
        # test_data = load_one_graph_data(
        #     'test_'+args.val_suffix, size=args.val_size, self_loop=args.self_loop, control=False)
        # test_data_loader = DataLoader(
        #     test_data, batch_size=args.val_bs, shuffle=True)

    if args.sampler == 'random':
        sampler = RandomDatasetSampler(train_dataset, args)
    elif args.sampler == 'uncertainty':
        sampler = MaximalEntropyDatasetSampler(train_dataset, args)
    else:
        raise ValueError(
            'Only random and uncertainty samplers are supported for now!')

    logger = Logger(save_folder)

    print('Doing initial validation before training...')
    val_control(args, log_prior, logger, save_folder, valid_data_loader, -1,
                decoder, rel_rec, rel_send, scheduler)

    data_idx = []
    nodes = []
    # Start with one group of each node
    for i in range(args.input_atoms):
        # group_size=args.variations
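        # Query 200 groups of args.variations value variations that intervene
        # on node i; every returned index is therefore labeled with node i below.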
        new_data_idx = sampler.sample([i], args.variations, 200)
        uncertain_nodes = torch.LongTensor([i] * args.variations * 200).cuda()
        data_idx.append(new_data_idx)
        nodes.append(uncertain_nodes)

    data_idx = torch.cat(data_idx)
    nodes = torch.cat(nodes)
    al_train_dataset = ALIndexDataset(train_dataset, data_idx, nodes)

    # while len(al_train_dataset.idxs) < args.budget:
    for epoch in range(500):
        # epoch = int(len(al_train_dataset)/args.log_data_size)
        # Sampler batch_size is fixed to be #value variations for a node
        # control_nodes = sampler.criterion(decoder.rel_graph, k=args.topk)
        # new_data_idx = sampler.sample(
        #     control_nodes, args.variations, args.sample_num_groups)
        # uncertain_nodes = torch.LongTensor(
        #     control_nodes).repeat_interleave(args.variations*args.sample_num_groups).cuda()
        # al_train_dataset.update(new_data_idx, uncertain_nodes)

        # with open(os.path.join(save_folder, str(epoch)+'_queries.pt'), 'wb') as f:
        #     torch.save(
        #         [al_train_dataset.nodes, al_train_dataset.idxs], f)
        #     print('sampled nodes this episode',
        #           control_nodes, new_data_idx, len(al_train_dataset))

        # A plain DataLoader is enough here: grouping and indexing are already
        # handled inside al_train_dataset. Just make sure shuffle=False.
        train_data_loader = DataLoader(al_train_dataset,
                                       batch_size=args.train_bs,
                                       shuffle=False)

        # for j in range(10):
        # print('epoch and j', epoch, j)
        nll, nll_lasttwo, kl, mse, control_constraint_loss, lr, rel_graphs, rel_graphs_grad, a, b, c, d, e, f = train_control(
            args, log_prior, optimizer, save_folder, train_data_loader,
            valid_data_loader, decoder, rel_rec, rel_send, epoch)

        if epoch % args.train_log_freq == 0:
            logger.log('train',
                       decoder,
                       epoch,
                       nll,
                       nll_lasttwo,
                       kl=kl,
                       mse=mse,
                       control_constraint_loss=control_constraint_loss,
                       lr=lr,
                       rel_graphs=rel_graphs,
                       rel_graphs_grad=rel_graphs_grad,
                       msg_hook_weights=a,
                       nll_train_lasttwo=b,
                       nll_train_lasttwo_5=c,
                       nll_train_lasttwo_10=d,
                       nll_train_lasttwo__1=e,
                       nll_train_lasttwo_1=f)

        if epoch % args.val_log_freq == 0:
            _ = val_control(args, log_prior, logger, save_folder,
                            valid_data_loader, epoch, decoder, rel_rec,
                            rel_send, scheduler)
        scheduler.step()

    print("Optimization Finished!")
    print("Best Epoch: {:04d}".format(logger.best_epoch))
    if args.save_folder:
        print("Best Epoch: {:04d}".format(logger.best_epoch), file=meta_file)
        meta_file.flush()

    # test_data_loader is only created in the commented-out block above;
    # re-enable that block before calling test_control.
    # test_control(test_data_loader)
    if meta_file is not None:
        print(save_folder)
        meta_file.close()
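
# The ALIndexDataset used above wraps the full training set and exposes only
# the actively-queried indices (see len(), .idxs, .nodes and .update() in the
# commented-out query block). A minimal sketch consistent with that usage
# (hypothetical; the real class may return a richer tuple per item). Assumes
# torch and torch.utils.data are imported.
class ALIndexDatasetSketch(torch.utils.data.Dataset):
    def __init__(self, base_dataset, idxs, nodes):
        self.base = base_dataset
        self.idxs = idxs      # indices of queried samples in the full dataset
        self.nodes = nodes    # intervened node for each queried sample

    def __len__(self):
        return self.idxs.size(0)

    def __getitem__(self, i):
        return self.base[self.idxs[i]], self.nodes[i]

    def update(self, new_idxs, new_nodes):
        # Append newly queried indices so the next DataLoader pass sees them.
        self.idxs = torch.cat([self.idxs, new_idxs])
        self.nodes = torch.cat([self.nodes, new_nodes])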
Example #4
    def reset(self, feature_extractors=True):  # data_num_per_obj=1):
        self.total_intervention = 0
        self.early_stop_monitor = EarlyStopping()
        # obj idx: obj attributes tensor
        self.obj = {
            i: self.discrete_mapping[0][i]
            for i in range(self.args.initial_obj_num)
        }
        # obj idx: tensor datapoints using that obj. Acts as the training pool.
        self.obj_data = {i: [] for i in range(self.obj_num)}
        self.train_dataset = RLDataset(torch.Tensor(), self.edge, self.mins,
                                       self.maxs)
        self.intervened_nodes = []

        self.causal_model = MLPDecoder_Causal(self.args, self.rel_rec,
                                              self.rel_send).cuda()
        self.causal_model_optimizer = optim.Adam(
            list(self.causal_model.parameters()) +
            [self.causal_model.rel_graph],
            lr=self.args.lr)
        self.scheduler = lr_scheduler.StepLR(self.causal_model_optimizer,
                                             step_size=self.args.lr_decay,
                                             gamma=self.args.gamma)

        self.init_train_data()
        load_weights = '100_warmup_weights.pt'
        if load_weights not in os.listdir(self.args.save_folder):
            print('No pretrained warm-up weights found, training them now.')
            train_data_loader = DataLoader(self.train_dataset,
                                           batch_size=self.args.train_bs,
                                           shuffle=False)

            lowest_loss = np.inf
            for i in range(1000):
                print('warm-up iter', i, 'save folder', self.args.save_folder,
                      'lowest val NLL so far', lowest_loss)
                nll, nll_lasttwo, kl, mse, control_constraint_loss, lr, rel_graphs, rel_graphs_grad, a, b, c, d, e = train_control(
                    self.args, self.log_prior, self.causal_model_optimizer,
                    self.save_folder, train_data_loader, self.causal_model,
                    self.epoch)

                # val_dataset should be continuous, more coverage
                nll_val, nll_lasttwo_val, kl_val, mse_val, a_val, b_val, c_val, control_constraint_loss_val, nll_lasttwo_5_val, nll_lasttwo_10_val, nll_lasttwo__1_val, nll_lasttwo_1_val = val_control(
                    self.args, self.log_prior, self.logger, self.save_folder,
                    self.valid_data_loader, self.epoch, self.causal_model)
                if nll_val < lowest_loss:
                    print('new lowest_loss', nll_val)
                    lowest_loss = nll_val
                    torch.save([
                        self.causal_model.state_dict(),
                        self.causal_model.rel_graph
                    ], os.path.join(self.args.save_folder, load_weights))

                self.logger.log(
                    'val',
                    self.causal_model,
                    i,
                    nll_val,
                    nll_lasttwo_val,
                    kl_val=kl_val,
                    mse_val=mse_val,
                    a_val=a_val,
                    b_val=b_val,
                    c_val=c_val,
                    control_constraint_loss_val=control_constraint_loss_val,
                    nll_lasttwo_5_val=nll_lasttwo_5_val,
                    nll_lasttwo_10_val=nll_lasttwo_10_val,
                    nll_lasttwo__1_val=nll_lasttwo__1_val,
                    nll_lasttwo_1_val=nll_lasttwo_1_val,
                    scheduler=self.scheduler)

                self.logger.log(
                    'train',
                    self.causal_model,
                    i,
                    nll,
                    nll_lasttwo,
                    kl_train=kl,
                    mse_train=mse,
                    control_constraint_loss_train=control_constraint_loss,
                    lr_train=lr,
                    rel_graphs=rel_graphs,
                    rel_graphs_grad=rel_graphs_grad,
                    msg_hook_weights_train=a,
                    nll_lasttwo_5_train=b,
                    nll_lasttwo_10_train=c,
                    nll_lasttwo__1_train=d,
                    nll_lasttwo_1_train=e)

        else:
            weights, graph = torch.load(
                os.path.join(self.args.save_folder, load_weights))
            self.causal_model.load_state_dict(weights)
            self.causal_model.rel_graph = graph.cuda()
            print('warm up weights loaded.')

        self.epoch += 1
        if feature_extractors:
            # Make sure the output dims of both encoders are the same!
            self.obj_extractor = MLPEncoder(self.args, 3, 128,
                                            self.args.extract_feat_dim).cuda()
            self.obj_extractor_optimizer = optim.Adam(
                list(self.obj_extractor.parameters()),
                lr=self.args.obj_extractor_lr)

            # TODO: try self.obj_data_extractor = MLPEncoder(args, 3, 64, 16).cuda()
            # Bidirectional LSTM
            self.obj_data_extractor = LSTMEncoder(
                self.args.num_atoms,
                self.args.extract_feat_dim,
                num_direction=self.lstm_direction,
                batch_first=True).cuda()
            self.obj_data_extractor_optimizer = optim.Adam(
                list(self.obj_data_extractor.parameters()),
                lr=self.args.obj_data_extractor_lr)
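            # Both extractors above are sized by args.extract_feat_dim so that
            # their output features match, as the comment at the top of this
            # block requires.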

            self.learning_assess_extractor = LSTMEncoder(
                8,
                self.args.extract_feat_dim,
                num_direction=self.lstm_direction,
                batch_first=True).cuda()
            self.learning_assess_extractor_optimizer = optim.Adam(
                list(self.learning_assess_extractor.parameters()),
                lr=self.args.learning_assess_extractor_lr)
def main():
    # Train model
    best_val_loss = np.inf
    best_epoch = 0

    if args.grouped:
        assert args.train_bs % args.variations == 0, \
            "Grouping the training set requires args.train_bs to be an integer multiple of args.variations"

        train_data = OneGraphDataset.load_one_graph_data(
            'train_causal_vel_' + args.suffix,
            train_data_min_max=None,
            size=args.train_size,
            self_loop=args.self_loop,
            control=args.grouped,
            control_nodes=args.input_atoms,
            variations=args.variations,
            need_grouping=args.need_grouping)
        train_sampler = RandomPytorchSampler(train_data)
        train_data_loader = DataLoader(train_data,
                                       batch_size=args.train_bs,
                                       shuffle=False,
                                       sampler=train_sampler)

    else:
        train_data = OneGraphDataset.load_one_graph_data(
            'train_causal_vel_' + args.suffix,
            train_data_min_max=None,
            size=args.train_size,
            self_loop=args.self_loop,
            control=args.grouped)
        train_data_loader = DataLoader(train_data,
                                       batch_size=args.train_bs,
                                       shuffle=True)

    if args.val_grouped:
        # To see control loss, val and test should be grouped
        valid_data = OneGraphDataset.load_one_graph_data(
            'valid_causal_vel_' + args.val_suffix,
            train_data_min_max=[train_data.mins, train_data.maxs],
            size=args.val_size,
            self_loop=args.self_loop,
            control=True,
            control_nodes=args.input_atoms,
            variations=args.val_variations,
            need_grouping=args.val_need_grouping)
        valid_sampler = RandomPytorchSampler(valid_data)
        valid_data_loader = DataLoader(valid_data,
                                       batch_size=args.val_bs,
                                       shuffle=False,
                                       sampler=valid_sampler)

        # test_data = load_one_graph_data(
        #     'test_'+args.val_suffix, size=args.test_size, self_loop=args.self_loop, control=True, control_nodes=args.input_atoms, variations=4)
        # test_sampler = RandomPytorchSampler(test_data)
        # test_data_loader = DataLoader(
        #     test_data, batch_size=args.val_bs, shuffle=False, sampler=test_sampler)
    else:
        valid_data = OneGraphDataset.load_one_graph_data(
            'valid_causal_vel_' + args.val_suffix,
            train_data_min_max=[train_data.mins, train_data.maxs],
            size=args.val_size,
            self_loop=args.self_loop,
            control=False)
        valid_data_loader = DataLoader(valid_data,
                                       batch_size=args.val_bs,
                                       shuffle=True)
        # test_data = load_one_graph_data(
        #     'test_'+args.val_suffix, size=args.val_size, self_loop=args.self_loop, control=False)
        # test_data_loader = DataLoader(
        #     test_data, batch_size=args.val_bs, shuffle=True)
    print('size of training dataset', len(train_data), 'size of valid dataset',
          len(valid_data))
    logger = Logger(save_folder)
    print('Doing initial validation before training...')
    nll_val, nll_lasttwo_val, kl_val, mse_val, a_val, b_val, c_val, control_constraint_loss_val, nll_lasttwo_5_val, nll_lasttwo_10_val, nll_lasttwo__1_val, nll_lasttwo_1_val = val_control(
        args, log_prior, logger, args.save_folder, valid_data_loader, -1,
        decoder)

    for epoch in range(args.epochs):
        # TODO: when len(train_dataset) reaches budget, force stop
        # print('#batches in train_dataset', len(train_dataset)/args.train_bs)
        nll, nll_lasttwo, kl, mse, control_constraint_loss, lr, rel_graphs, rel_graphs_grad, a, b, c, d, e = train_control(
            args, log_prior, optimizer, save_folder, train_data_loader,
            decoder, epoch)

        if epoch % args.train_log_freq == 0:
            logger.log('train',
                       decoder,
                       epoch,
                       nll,
                       nll_lasttwo,
                       kl=kl,
                       mse=mse,
                       control_constraint_loss=control_constraint_loss,
                       lr_train=lr,
                       rel_graphs=rel_graphs,
                       rel_graphs_grad=rel_graphs_grad,
                       msg_hook_weights=a,
                       nll_train_lasttwo=b,
                       nll_lasttwo_10_train=c,
                       nll_lasttwo__1_train=d,
                       nll_lasttwo_1_train=e)

        if epoch % args.train_log_freq == 0:
            nll_val, nll_lasttwo_val, kl_val, mse_val, a_val, b_val, c_val, control_constraint_loss_val, nll_lasttwo_5_val, nll_lasttwo_10_val, nll_lasttwo__1_val, nll_lasttwo_1_val = val_control(
                args, log_prior, logger, args.save_folder, valid_data_loader,
                epoch, decoder)

            logger.log('val',
                       decoder,
                       epoch,
                       nll_val,
                       nll_lasttwo_val,
                       kl_val=kl_val,
                       mse_val=mse_val,
                       a_val=a_val,
                       b_val=b_val,
                       c_val=c_val,
                       control_constraint_loss_val=control_constraint_loss_val,
                       nll_lasttwo_5_val=nll_lasttwo_5_val,
                       nll_lasttwo_10_val=nll_lasttwo_10_val,
                       nll_lasttwo__1_val=nll_lasttwo__1_val,
                       nll_lasttwo_1_val=nll_lasttwo_1_val,
                       scheduler=scheduler)

        # if epoch % args.val_log_freq == 0:
        #     _ = val_control(
        #         args, log_prior, logger, save_folder, valid_data_loader, epoch, decoder, scheduler)
        scheduler.step()

    print("Optimization Finished!")
    print("Best Epoch: {:04d}".format(logger.best_epoch))
    if args.save_folder:
        print("Best Epoch: {:04d}".format(logger.best_epoch), file=meta_file)
        meta_file.flush()

    # test_data_loader is only created in the commented-out block above;
    # re-enable that block before calling test_control.
    # test_control(test_data_loader)
    if meta_file is not None:
        print(save_folder)
        meta_file.close()