Example #1
File: A2C.py Project: karunraju/NFF
    def __init__(self, ReplayBuffer, action_space=3, network=None):
        self.lr = PARAM.LEARNING_RATE
        self.N = PARAM.N
        self.gamma = PARAM.gamma
        self.seq_len = PARAM.A2C_SEQUENCE_LENGTH
        self.aux_batch_size = PARAM.AUX_TASK_BATCH_SIZE
        self.vfr_weight = PARAM.VFR_LOSS_WEIGHT
        self.rp_weight = PARAM.RP_LOSS_WEIGHT
        self.pc_weight = PARAM.PC_LOSS_WEIGHT
        self.gpu = torch.cuda.is_available()

        # A2C network
        if PARAM.ENSEMBLE < 1:
            self.A = AuxNetwork(state_size=PARAM.STATE_SIZE,
                                action_space=action_space,
                                seq_len=self.seq_len)
            # GPU availability
            if self.gpu:
                print("Using GPU")
                self.A = self.A.cuda()
            else:
                print("Using CPU")
            self.replay_buffer = ReplayBuffer(PARAM.REPLAY_MEMORY_SIZE)
            # Loss Function and Optimizer
            self.optimizer = optim.Adam(self.A.parameters(),
                                        lr=self.lr,
                                        weight_decay=1e-6)
        else:
            self.Ensemble = Ensemble(PARAM.ENSEMBLE, action_space,
                                     self.seq_len, ReplayBuffer, network)
            self.source_context()

        self.vfr_criterion = nn.MSELoss()  # Value Function Replay loss
        self.rp_criterion = nn.CrossEntropyLoss()  # Reward Prediction loss
        self.pc_criterion = nn.MSELoss()  # Pixel Control loss
Example #2
    def __init__(self, params):
        super(POLOAgent, self).__init__(params)
        self.H_backup = self.params['polo']['H_backup']

        # Create ensemble of value functions
        model_params = params['polo']['ens_params']['model_params']
        model_params['input_size'] = self.N
        model_params['output_size'] = 1

        params['polo']['ens_params']['dtype'] = self.dtype
        params['polo']['ens_params']['device'] = self.device

        self.val_ens = Ensemble(self.params['polo']['ens_params'])

        # Learn from replay buffer
        self.polo_buf = ReplayBuffer(self.N, self.M,
                                     self.params['polo']['buf_size'])

        # Value (from forward), value mean, value std
        self.hist['vals'] = np.zeros((self.T, 3))
Example #3
def main():
    teacher = MGA_Network(nInputChannels=3,
                          n_classes=1,
                          os=16,
                          img_backbone_type='resnet101',
                          flow_backbone_type='resnet34')
    teacher = load_MGA(teacher, args['mga_model_path'], device_id=device_id)
    teacher.eval()
    teacher.cuda(device_id2)

    student = Ensemble(device_id).cuda(device_id).train()

    # fix_parameters(net.named_parameters())
    optimizer = optim.SGD(
        [{
            'params': [
                param for name, param in student.named_parameters()
                if name[-4:] == 'bias'
            ],
            'lr': 2 * args['lr']
        }, {
            'params': [
                param for name, param in student.named_parameters()
                if name[-4:] != 'bias'
            ],
            'lr': args['lr'],
            'weight_decay': args['weight_decay']
        }],
        momentum=args['momentum'])

    if len(args['snapshot']) > 0:
        print('training resumes from ' + args['snapshot'])
        student.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             args['snapshot'] + '_optim.pth')))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    if len(args['pretrain']) > 0:
        print('pretrain model from ' + args['pretrain'])
        student = load_part_of_model(student,
                                     args['pretrain'],
                                     device_id=device_id)
        # fix_parameters(student.named_parameters())

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    with open(log_path, 'w') as f:
        f.write(str(args) + '\n\n')
    train(student, teacher, optimizer)
Example #4
from config import Config
from utils.data_utils import read_data, write_scores, write_predictions, get_segment, build_dictionary, remove_low_words
from models.BowModel import BowModel
from models.BocModel import BocModel
from models.Ensemble import Ensemble

if __name__ == "__main__":
    config = Config()

    # read data
    train_set = read_data(config.train_set_file_name)
    dev_set = read_data(config.dev_set_file_name)

    # segmentation
    train_set = get_segment(train_set)
    dev_set = get_segment(dev_set)

    # remove words with low frequency
    dictionary = build_dictionary([train_set, dev_set], config.low_frequency,
                                  config.high_frequency)
    train_set = remove_low_words(train_set, dictionary)
    dev_set = remove_low_words(dev_set, dictionary)

    # get predictions
    ensemble_model = Ensemble(config, [BowModel(config), BocModel(config)])
    scores = ensemble_model.test(dev_set)

    # write predictions
    # write_predictions(dev_set, labels, config.result_file_name)
    write_scores(scores, config.result_file_name)
Example #5
    if exp_config['save_models']:
        model_dir = exp_config['save_models_path'] + args.bin_name + exp_config['ts'] + '/'
        if not os.path.isdir(model_dir):
            os.makedirs(model_dir)
        # Copy the config file for bookkeeping
        copy2(args.config, model_dir)
        with open(model_dir+'args.json', 'w') as f:
            json.dump(vars(args), f) # converting args.namespace to dict

    float_tensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
    torch.manual_seed(exp_config['seed'])
    if use_cuda:
        torch.cuda.manual_seed_all(exp_config['seed'])

    # Init Model
    model = Ensemble(**ensemble_args)
    # TODO Checkpoint loading

    if use_cuda:
        model.cuda()
        model = DataParallel(model)
    print(model)

    if args.resnet:
        cnn = ResNet()

        if use_cuda:
            cnn.cuda()
            cnn = DataParallel(cnn)

    softmax = nn.Softmax(dim=-1)
    torch.manual_seed(exp_config['seed'])
    if use_cuda:
        torch.cuda.manual_seed_all(exp_config['seed'])

    if exp_config['logging']:
        log_dir = exp_config['logdir']+str(args.exp_name)+exp_config['ts']+'/'
        if not os.path.isdir(log_dir):
            os.makedirs(log_dir)
        copy2(args.config, log_dir)
        copy2(args.ens_config, log_dir)
        copy2(args.or_config, log_dir)
        with open(log_dir+'args.txt', 'w') as f:
            f.write(str(vars(args))) # converting args.namespace to dict

    model = Ensemble(**ensemble_args)
    model = load_model(model, ensemble_args['bin_file'], use_dataparallel=use_dataparallel)
    model.eval()

    oracle = Oracle(
        no_words            = oracle_args['vocab_size'],
        no_words_feat       = oracle_args['embeddings']['no_words_feat'],
        no_categories       = oracle_args['embeddings']['no_categories'],
        no_category_feat    = oracle_args['embeddings']['no_category_feat'],
        no_hidden_encoder   = oracle_args['lstm']['no_hidden_encoder'],
        mlp_layer_sizes     = oracle_args['mlp']['layer_sizes'],
        no_visual_feat      = oracle_args['inputs']['no_visual_feat'],
        no_crop_feat        = oracle_args['inputs']['no_crop_feat'],
        dropout             = oracle_args['lstm']['dropout'],
        inputs_config       = oracle_args['inputs'],
        scale_visual_to     = oracle_args['inputs']['scale_visual_to']
Example #7
def main():
    net = Ensemble(device_id, pretrained=False)

    print ('load snapshot \'%s\' for testing' % args['snapshot'])
    # net.load_state_dict(torch.load('pretrained/R2Net.pth', map_location='cuda:2'))
    # net = load_part_of_model2(net, 'pretrained/R2Net.pth', device_id=2)
    net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth'),
                                   map_location='cuda:' + str(device_id)))
    net.eval()
    net.cuda()
    results = {}

    with torch.no_grad():

        for name, root in to_test.items():

            precision_record, recall_record, = [AvgMeter() for _ in range(256)], [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()

            if args['save_results']:
                check_mkdir(os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot'])))
            img_list = [i_id.strip() for i_id in open(imgs_path)]
            for idx, img_name in enumerate(img_list):
                print('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))
                print(img_name)

                if name == 'VOS' or name == 'DAVSOD':
                    img = Image.open(os.path.join(root, img_name + '.png')).convert('RGB')
                else:
                    img = Image.open(os.path.join(root, img_name + '.jpg')).convert('RGB')
                shape = img.size
                img = img.resize(args['input_size'])
                img_var = img_transform(img).unsqueeze(0).cuda()  # no Variable/volatile needed inside torch.no_grad()
                start = time.time()
                outputs_a, outputs_c = net(img_var)
                a_out1u, a_out2u, a_out2r, a_out3r, a_out4r, a_out5r = outputs_a  # F3Net
                # b_outputs0, b_outputs1 = outputs_b  # CPD
                c_outputs0, c_outputs1, c_outputs2, c_outputs3, c_outputs4 = outputs_c  # RAS
                prediction = torch.sigmoid(c_outputs0)
                end = time.time()
                print('running time:', (end - start))
                # e = Erosion2d(1, 1, 5, soft_max=False).cuda()
                # prediction2 = e(prediction)
                #
                # precision2 = to_pil(prediction2.data.squeeze(0).cpu())
                # precision2 = prediction2.data.squeeze(0).cpu().numpy()
                # precision2 = precision2.resize(shape)
                # prediction2 = np.array(precision2)
                # prediction2 = prediction2.astype('float')

                precision = to_pil(prediction.data.squeeze(0).cpu())
                precision = precision.resize(shape)
                prediction = np.array(precision)
                prediction = prediction.astype('float')

                # plt.style.use('classic')
                # plt.subplot(1, 2, 1)
                # plt.imshow(prediction)
                # plt.subplot(1, 2, 2)
                # plt.imshow(precision2[0])
                # plt.show()

                prediction = MaxMinNormalization(prediction, prediction.max(), prediction.min()) * 255.0
                prediction = prediction.astype('uint8')
                # if args['crf_refine']:
                #     prediction = crf_refine(np.array(img), prediction)

                gt = np.array(Image.open(os.path.join(gt_root, img_name + '.png')).convert('L'))
                precision, recall, mae = cal_precision_recall_mae(prediction, gt)
                for pidx, pdata in enumerate(zip(precision, recall)):
                    p, r = pdata
                    precision_record[pidx].update(p)
                    recall_record[pidx].update(r)
                mae_record.update(mae)

                if args['save_results']:
                    folder, sub_name = os.path.split(img_name)
                    save_path = os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot']), folder)
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    Image.fromarray(prediction).save(os.path.join(save_path, sub_name + '.png'))

            fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                                    [rrecord.avg for rrecord in recall_record])

            results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

    print ('test results:')
    print (results)
    log_path = 'result_all.txt'
    with open(log_path, 'a') as f:
        f.write(exp_name + ' ' + args['snapshot'] + '\n')
        f.write(str(results) + '\n\n')
Example #8
        with open(log_dir + 'args.txt', 'w') as f:
            f.write(str(vars(args)))  # converting args.namespace to dict

    if exp_config['save_models']:
        model_dir = exp_config[
            'save_models_path'] + args.bin_name + exp_config['ts'] + '/'
        if not os.path.isdir(model_dir):
            os.makedirs(model_dir)
        # Duplicate copies of the config files, again just for bookkeeping
        copy2(args.config, model_dir)
        copy2(args.ens_config, model_dir)
        copy2(args.or_config, model_dir)
        with open(model_dir + 'args.txt', 'w') as f:
            f.write(str(vars(args)))  # converting args.namespace to dict

    model = Ensemble(**ensemble_args)
    model = load_model(model,
                       ensemble_args['bin_file'],
                       use_dataparallel=use_dataparallel)
    # model.eval()

    oracle = Oracle(
        no_words=oracle_args['vocab_size'],
        no_words_feat=oracle_args['embeddings']['no_words_feat'],
        no_categories=oracle_args['embeddings']['no_categories'],
        no_category_feat=oracle_args['embeddings']['no_category_feat'],
        no_hidden_encoder=oracle_args['lstm']['no_hidden_encoder'],
        mlp_layer_sizes=oracle_args['mlp']['layer_sizes'],
        no_visual_feat=oracle_args['inputs']['no_visual_feat'],
        no_crop_feat=oracle_args['inputs']['no_crop_feat'],
        dropout=oracle_args['lstm']['dropout'],
Example #9
class POLOAgent(MPCAgent):
    """
    MPC-based agent that uses the Plan Online, Learn Offline (POLO) framework
    (Lowrey et al., 2018) for trajectory optimization.
    """
    def __init__(self, params):
        super(POLOAgent, self).__init__(params)
        self.H_backup = self.params['polo']['H_backup']

        # Create ensemble of value functions
        model_params = params['polo']['ens_params']['model_params']
        model_params['input_size'] = self.N
        model_params['output_size'] = 1

        params['polo']['ens_params']['dtype'] = self.dtype
        params['polo']['ens_params']['device'] = self.device

        self.val_ens = Ensemble(self.params['polo']['ens_params'])

        # Learn from replay buffer
        self.polo_buf = ReplayBuffer(self.N, self.M,
                                     self.params['polo']['buf_size'])

        # Value (from forward), value mean, value std
        self.hist['vals'] = np.zeros((self.T, 3))

    def get_action(self, prior=None):
        """
        POLO selects action based on MPC optimization with an optimistic
        terminal value function.
        """
        self.val_ens.eval()

        # Get value of current state
        s = torch.tensor(self.prev_obs, dtype=self.dtype)
        s = s.to(device=self.device)
        current_val = self.val_ens.forward(s)[0]
        current_val = torch.squeeze(current_val, -1)
        current_val = current_val.detach().cpu().numpy()

        # Get prediction of every function in ensemble
        preds = self.val_ens.get_preds_np(self.prev_obs)

        # Log information from value function
        self.hist['vals'][self.time] = \
            np.array([current_val, np.mean(preds), np.std(preds)])

        # Run MPC to get action
        act = super(POLOAgent, self).get_action(terminal=self.val_ens,
                                                prior=prior)

        return act

    def action_taken(self, prev_obs, obs, rew, done, ifo):
        """
        Update buffer for value function learning.
        """
        self.polo_buf.update(prev_obs, obs, rew, done)

    def do_updates(self):
        """
        POLO learns a value function from its past true history of interactions
        with the environment.
        """
        super(POLOAgent, self).do_updates()
        if self.time % self.params['polo']['update_freq'] == 0:
            self.val_ens.update_from_buf(self.polo_buf,
                                         self.params['polo']['grad_steps'],
                                         self.params['polo']['batch_size'],
                                         self.params['polo']['H_backup'],
                                         self.gamma)

    def print_logs(self):
        """
        POLO-specific logging information.
        """
        bi, ei = super(POLOAgent, self).print_logs()

        self.print('POLO metrics', mode='head')

        self.print('current state val', self.hist['vals'][self.time - 1][0])
        self.print('current state std', self.hist['vals'][self.time - 1][2])

        return bi, ei
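
For context, the POLOAgent above is configured entirely through a nested params dictionary. Below is a minimal sketch of that structure, using only the keys the snippet actually reads; every concrete value is a hypothetical placeholder, and the MPCAgent-level entries are assumptions rather than part of the original project.

# Hypothetical configuration sketch for POLOAgent. Only the 'polo' keys come
# from the snippet above; all values are placeholders.
polo_params = {
    'polo': {
        'H_backup': 4,          # backup horizon passed to the value ensemble update
        'buf_size': 100000,     # capacity of the POLO replay buffer
        'update_freq': 25,      # do_updates() trains the ensemble every update_freq steps
        'grad_steps': 32,       # gradient steps per ensemble update
        'batch_size': 64,       # minibatch size per gradient step
        'ens_params': {
            'model_params': {},  # 'input_size'/'output_size' are filled in __init__
        },
    },
    # MPCAgent-level settings (state/action sizes, horizon, gamma, dtype,
    # device, ...) would also have to be present before POLOAgent(polo_params)
    # could be constructed.
}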
Example #10
File: A2C.py Project: karunraju/NFF
class A2C():
    def __init__(self, ReplayBuffer, action_space=3, network=None):
        self.lr = PARAM.LEARNING_RATE
        self.N = PARAM.N
        self.gamma = PARAM.gamma
        self.seq_len = PARAM.A2C_SEQUENCE_LENGTH
        self.aux_batch_size = PARAM.AUX_TASK_BATCH_SIZE
        self.vfr_weight = PARAM.VFR_LOSS_WEIGHT
        self.rp_weight = PARAM.RP_LOSS_WEIGHT
        self.pc_weight = PARAM.PC_LOSS_WEIGHT
        self.gpu = torch.cuda.is_available()

        # A2C network
        if PARAM.ENSEMBLE < 1:
            self.A = AuxNetwork(state_size=PARAM.STATE_SIZE,
                                action_space=action_space,
                                seq_len=self.seq_len)
            # GPU availability
            if self.gpu:
                print("Using GPU")
                self.A = self.A.cuda()
            else:
                print("Using CPU")
            self.replay_buffer = ReplayBuffer(PARAM.REPLAY_MEMORY_SIZE)
            # Loss Function and Optimizer
            self.optimizer = optim.Adam(self.A.parameters(),
                                        lr=self.lr,
                                        weight_decay=1e-6)
        else:
            self.Ensemble = Ensemble(PARAM.ENSEMBLE, action_space,
                                     self.seq_len, ReplayBuffer, network)
            self.source_context()

        self.vfr_criterion = nn.MSELoss()  # Value Function Replay loss
        self.rp_criterion = nn.CrossEntropyLoss()  # Reward Prediction loss
        self.pc_criterion = nn.MSELoss()  # Pixel Control loss

    def add_episodic_buffer(self, episode_buffer):
        self.episode_buffer = episode_buffer

    def reduce_learning_rate(self):
        if PARAM.ENSEMBLE >= 1:
            self.Ensemble.reduce_learning_rate()
            return

        for pgroups in self.optimizer.param_groups:
            pgroups['lr'] = pgroups['lr'] / 10.0

    def source_context(self):
        if PARAM.ENSEMBLE == 0:
            return
        self.Ensemble.update_context()
        self.A = self.Ensemble.get_network()
        self.optimizer = self.Ensemble.get_optimizer()
        self.replay_buffer = self.Ensemble.get_replay_buffer()

    def train(self, episode_number, episode_len):
        self.optimizer.zero_grad()
        loss = self.compute_A2C_loss(episode_len)
        if episode_number > 5:
            loss += self.vfr_weight * self.compute_vfr_loss()
            if self.replay_buffer.any_reward_instances():
                loss += self.rp_weight * self.compute_rp_loss()
            loss += self.pc_weight * self.compute_pc_loss()
        loss.backward()
        torch.nn.utils.clip_grad_value_(self.A.parameters(),
                                        PARAM.GRAD_CLIP_VAL)
        self.optimizer.step()

        if math.isnan(loss.item()):
            print('Loss Exploded!')

    def compute_A2C_loss(self, episode_len):
        T = episode_len
        n = self.N
        for t in range(T - 1, -1, -1):
            val = self.episode_buffer[t][-1]
            if t + n >= T:
                Vend = 0
            else:
                Vend = self.episode_buffer[t + n][-1]
            sum_ = 0.0
            for k in range(n):
                if t + k < T:
                    tk_reward = self.episode_buffer[t + k][2]
                    sum_ += tk_reward * (self.gamma**k)
            rew = Vend * (self.gamma**n) + float(sum_)
            if t == T - 1:
                ploss = (rew - val) * torch.log(
                    self.episode_buffer[t][4][self.episode_buffer[t][1]])
                vloss = (rew - val)**2
            else:
                ploss += (rew - val) * torch.log(
                    self.episode_buffer[t][4][self.episode_buffer[t][1]])
                vloss += (rew - val)**2

        ploss = -1.0 * ploss / float(T)
        vloss = vloss / float(T)

        return ploss + vloss

    def compute_vfr_loss(self):
        """ Computes Value Function Replay Loss. """
        idxs = self.replay_buffer.sample_idxs(self.aux_batch_size)
        vision, scent, state, reward = self.get_io_from_replay_buffer(
            idxs, batch_size=self.aux_batch_size, seq_len=self.seq_len)
        val, _ = self.A.forward(vision, scent, state)

        return self.vfr_criterion(val.view(-1, 1), reward)

    def compute_rp_loss(self):
        """ Computes Reward Prediction Loss. """
        vision, ground_truth = self.get_io_from_skewed_replay_buffer(
            batch_size=self.aux_batch_size, seq_len=3)
        pred = self.A.predict_rewards(vision)

        return self.rp_criterion(pred, ground_truth)

    def compute_pc_loss(self):
        """ Computes Pixel Control Loss. """
        idxs = self.replay_buffer.sample_idxs(self.aux_batch_size)
        vision, aux_rew, actions = self.get_pc_io_from_replay_buffer(
            idxs, batch_size=self.aux_batch_size, seq_len=1)
        pred = self.A.pixel_control(vision)
        for i in range(20):
            if i == 0:
                pc_loss = self.pc_criterion(aux_rew[i], pred[i, actions[i]])
            else:
                pc_loss += self.pc_criterion(aux_rew[i], pred[i, actions[i]])

        return pc_loss

    def get_output(self, index, batch_size=1, seq_len=1, no_grad=False):
        ''' Returns output from the A network. '''
        vision, scent, state = self.get_input_tensor(index, batch_size,
                                                     seq_len)
        if PARAM.ENSEMBLE != 0:
            self.A = self.Ensemble.get_network()
            self.optimizer = self.Ensemble.get_optimizer()
        if no_grad:
            with torch.no_grad():
                val, softmax = self.A.forward(vision, scent, state)
        else:
            val, softmax = self.A.forward(vision, scent, state)

        action = np.random.choice(np.arange(3),
                                  1,
                                  p=np.squeeze(
                                      softmax.clone().cpu().detach().numpy()))
        return val, softmax.view(3), action

    def get_input_tensor(self, idxs, batch_size=1, seq_len=1):
        ''' Returns an input tensor from the observation. '''
        vision = np.zeros((batch_size, seq_len, 3, 11, 11))
        scent = np.zeros((batch_size, seq_len, 3))
        state = np.zeros((batch_size, seq_len, 4))

        for k, idx in enumerate(idxs):
            for j in range(seq_len):
                if idx - j < 0:
                    continue
                obs, action, rew, _, _, tong_count, _ = \
                    self.episode_buffer[idx - j]
                vision[k, j] = np.moveaxis(obs['vision'], -1, 0)
                scent[k, j] = obs['scent']
                state[k, j] = np.array(
                    [action, rew, int(obs['moved']), tong_count])

        vision = torch.from_numpy(vision).float()
        scent = torch.from_numpy(scent).float()
        state = torch.from_numpy(state).float()
        if self.gpu:
            vision, scent, state = vision.cuda(), scent.cuda(), state.cuda()

        return vision, scent, state

    def get_io_from_replay_buffer(self, idxs, batch_size=1, seq_len=1):
        ''' Returns an input tensor from the observation. '''
        vision = np.zeros((batch_size, seq_len, 3, 11, 11))
        scent = np.zeros((batch_size, seq_len, 3))
        state = np.zeros((batch_size, seq_len, 4))
        reward = np.zeros((batch_size, 1))

        for k, idx in enumerate(idxs):
            for j in range(seq_len):
                obs, action, rew, _, _, tong_count = self.replay_buffer.get_single_sample(
                    idx - j)
                vision[k, j] = np.moveaxis(obs['vision'], -1, 0)
                scent[k, j] = obs['scent']
                state[k, j] = np.array(
                    [action, rew, int(obs['moved']), tong_count])
                if j == 0:
                    reward[k] = rew

        vision = torch.from_numpy(vision).float()
        scent = torch.from_numpy(scent).float()
        state = torch.from_numpy(state).float()
        reward = torch.from_numpy(reward).float()
        if self.gpu:
            vision, scent, state = vision.cuda(), scent.cuda(), state.cuda()
            reward = reward.cuda()

        return vision, scent, state, reward

    def get_io_from_skewed_replay_buffer(self, batch_size=1, seq_len=1):
        ''' Returns an input tensor from the observation. '''
        vision, reward_class = self.replay_buffer.skewed_samples(
            batch_size, seq_len)
        vision = torch.from_numpy(vision).float()
        reward_class = torch.from_numpy(reward_class).long()
        if self.gpu:
            vision, reward_class = vision.cuda(), reward_class.cuda()

        return vision, reward_class

    def get_pc_io_from_replay_buffer(self, idxs, batch_size=1, seq_len=1):
        ''' Returns an input tensor from the observation. '''
        vision = np.zeros((batch_size, seq_len, 3, 11, 11))
        aux_rew = np.zeros((batch_size, 11, 11))
        actions = [[]] * batch_size

        for k, idx in enumerate(idxs):
            for j in range(seq_len):
                obs, action, _, next_obs, _, _ = self.replay_buffer.get_single_sample(
                    idx - j)
                vision[k, j] = np.moveaxis(obs['vision'], -1, 0)
                if j == 0:
                    if next_obs['moved']:
                        aux_rew[k] = np.mean(np.abs(obs['vision'] -
                                                    next_obs['vision']),
                                             axis=2)
                    actions[k] = action

        vision = torch.from_numpy(vision).float()
        aux_rew = torch.from_numpy(aux_rew).float()
        if self.gpu:
            vision, aux_rew = vision.cuda(), aux_rew.cuda()

        return vision, aux_rew, actions

    def set_train(self):
        self.A.train()

    def set_eval(self):
        self.A.eval()

    def save_model_weights(self, suffix, path='./'):
        # Helper function to save your model / weights.
        if PARAM.ENSEMBLE != 0:
            self.Ensemble.save()
            return
        state = {
            'epoch': suffix,
            'state_dict': self.A.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        torch.save(state, path + str(suffix) + '.dat')

    def load_model(self, model_file):
        # Helper function to load an existing model.
        if PARAM.ENSEMBLE != 0:
            print('Loading Ensemble models')
            self.Ensemble.load()
            return
        #state = torch.load(model_file)
        #self.A.load_state_dict(state['state_dict'])
        #self.optimizer.load_state_dict(state['optimizer'])
        self.A.load(model_file)

    def get_replay_buffer(self):
        if PARAM.ENSEMBLE != 0:
            return self.Ensemble.get_replay_buffer()
        return self.replay_buffer

    def get_action_repeat(self):
        if PARAM.ENSEMBLE != 0:
            return self.Ensemble.get_action_repeat()
        return PARAM.ACTION_REPEAT

    def monitor(self, rewards_list):
        if PARAM.ENSEMBLE == 0:
            return
        self.Ensemble.analyze_rewards(rewards_list)
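
For reference, here is a rough driver loop for the A2C class above, written against only the methods visible in this snippet. collect_episode, NUM_EPISODES and SAVE_EVERY are placeholders and not part of the original project; the ReplayBuffer argument is the same buffer class the constructor already expects.

# Hypothetical training driver for the A2C class above. collect_episode is a
# stand-in for the environment-interaction code and is assumed to return the
# per-step tuples that compute_A2C_loss indexes into, plus the episode length.
NUM_EPISODES = 1000  # placeholder
SAVE_EVERY = 100     # placeholder

def run_training(A2C, ReplayBuffer, collect_episode):
    agent = A2C(ReplayBuffer, action_space=3)
    agent.set_train()
    for episode in range(NUM_EPISODES):
        episode_buffer, episode_len = collect_episode(agent)
        agent.add_episodic_buffer(episode_buffer)
        agent.train(episode, episode_len)
        if episode % SAVE_EVERY == 0:
            agent.save_model_weights(episode)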