Example 1
    #                      betas=(0.5, 0.9))

    best_loss = np.inf
    for _epoch in range(N_epoch):
        # adjust learning rate
        for param_group in optimizer.param_groups:
            param_group['lr'] *= gamma

        optimizer.zero_grad()

        # forward pass: fit the shooting chunks and compute the loss
        prediction_chunks, data_chunks = multi_shoot.fit_and_grad(input_tensor, time_points)
        loss = multi_shoot.get_loss(prediction_chunks, data_chunks)

        loss.backward(retain_graph=False)
        optimizer.step()

        print('Epoch {} Alpha {}, Beta {}, Gamma {}, Delta {}'.format(
            _epoch, dcmfunc.alpha, dcmfunc.beta, dcmfunc.gamma, dcmfunc.delta))

        if loss.item() < best_loss:
            # keep the best fit: concatenate chunks along time, dropping each
            # chunk's last step so adjacent chunks do not repeat their shared endpoint
            prediction2, data2 = [], []
            for prediction, data in zip(prediction_chunks, data_chunks):
                if data.shape[0] > 1:
                    prediction2.append(prediction[:-1, ...])
                    data2.append(data[:-1, ...])

            prediction_all = torch.cat(prediction2, 0)
            data_all = torch.cat(data2, 0)

            best_loss = loss.item()
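
The snippet above begins mid-call: only the commented-out betas=(0.5, 0.9) argument of the optimizer construction survives. A minimal setup consistent with that fragment might look like the sketch below; the parameter source (multi_shoot.parameters()) and the learning rate are assumptions, not part of the original example.

from adabelief_pytorch import AdaBelief

# assumed construction; only betas=(0.5, 0.9) appears in the original snippet
optimizer = AdaBelief(multi_shoot.parameters(), lr=1e-2,
                      eps=1e-16, betas=(0.5, 0.9),
                      weight_decouple=True, rectify=True)

Note also that the per-epoch loop multiplying each param_group's lr by gamma is a manual exponential decay; torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma) would achieve the same effect.
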
Example 2
def memo_valor(env_fn,
               model=MEMO,
               memo_kwargs=dict(),
               annealing_kwargs=dict(),
               seed=0,
               episodes_per_expert=40,
               epochs=50,
               # warmup=10,
               train_iters=5,
               step_size=5,
               memo_lr=1e-3,
               train_batch_size=50,
               eval_batch_size=200,
               max_ep_len=1000,
               logger_kwargs=dict(),
               config_name='standard',
               save_freq=10,
               # replay_buffers=[],
               memories=[]):
    # W&B Logging
    wandb.login()

    composite_name = 'E ' + str(epochs) + ' B ' + str(train_batch_size) + ' ENC ' + \
                     str(memo_kwargs['encoder_hidden']) + ' DEC ' + str(memo_kwargs['decoder_hidden'])

    wandb.init(project="MEMO", group='Epochs: ' + str(epochs), name=composite_name, config=locals())

    assert memories != [], "No examples found! Replay/memory buffers must be set to proceed."

    # Special function to avoid certain slowdowns from PyTorch + MPI combo.
    setup_pytorch_for_mpi()

    # Set up logger and save configuration
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Instantiate environment
    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    # Create the MEMO model and monitor it
    con_dim = len(memories)  # number of expert contexts
    memo = model(obs_dim=obs_dim[0], out_dim=act_dim[0], **memo_kwargs)

    # Set up model saving
    logger.setup_pytorch_saver([memo])

    # Sync params across processes
    sync_params(memo)
    N_expert = episodes_per_expert*max_ep_len
    print("N Expert: ", N_expert)

    # Buffer
    # local_episodes_per_epoch = int(episodes_per_epoch / num_procs())
    local_iter_per_epoch = int(train_iters / num_procs())

    # Count variables
    var_counts = tuple(count_vars(module) for module in [memo])
    logger.log('\nNumber of parameters: \t d: %d\n' % var_counts)

    # Optimizers
    # memo_optimizer = AdaBelief(memo.parameters(), lr=memo_lr, eps=1e-20, rectify=True)
    memo_optimizer = AdaBelief(memo.parameters(), lr=memo_lr, eps=1e-16, rectify=True)
    # memo_optimizer = Adam(memo.parameters(), lr=memo_lr, betas=(0.9, 0.98), eps=1e-9)

    start_time = time.time()

    # Prepare data
    mem = MemoryBatch(memories, step=step_size)

    transition_states, pure_states, transition_actions, expert_ids = mem.collate()
    total_l_old, recon_l_old, context_l_old = 0, 0, 0

    # Main Loop
    kl_beta_schedule = frange_cycle_sigmoid(epochs, **annealing_kwargs)

    for epoch in range(epochs):
        memo.train()

        # Select state transitions and actions at random indexes
        batch_indexes = torch.randint(len(transition_states), (train_batch_size,))

        raw_states_batch, delta_states_batch, actions_batch, sampled_experts = \
            pure_states[batch_indexes], transition_states[batch_indexes], \
            transition_actions[batch_indexes], expert_ids[batch_indexes]

        for i in range(local_iter_per_epoch):
            # kl_beta = kl_beta_schedule[epoch]
            kl_beta = 1  # the sigmoid annealing schedule is overridden with a constant weight here
            # only take context labeling into account for first label
            loss, recon_loss, X, latent_labels, vq_loss = memo(raw_states_batch, delta_states_batch,
                                                               actions_batch, kl_beta)
            memo_optimizer.zero_grad()
            loss.mean().backward()
            mpi_avg_grads(memo)
            memo_optimizer.step()

        # scheduler.step(loss.mean().data.item())

        total_l_new, recon_l_new, vq_l_new = loss.mean().data.item(), recon_loss.mean().data.item(), vq_loss.mean().data.item()

        memo_metrics = {'MEMO Loss': total_l_new, 'Recon Loss': recon_l_new, "VQ Labeling Loss": vq_l_new,
                        "KL Beta": kl_beta}  # log the KL weight actually used this epoch
        wandb.log(memo_metrics)

        logger.store(TotalLoss=total_l_new, PolicyLoss=recon_l_new, # ContextLoss=context_l_new,
                     DeltaTotalLoss=total_l_new-total_l_old, DeltaPolicyLoss=recon_l_new-recon_l_old,
                     )

        total_l_old, recon_l_old = total_l_new, recon_l_new  # , context_l_new

        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, [memo], None)

        # Log
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpochBatchSize', train_batch_size)
        logger.log_tabular('TotalLoss', average_only=True)
        logger.log_tabular('PolicyLoss', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()

    print("Finished training, and detected %d contexts!" % len(memo.found_contexts))
    # wandb.finish()
    print('memo type:', type(memo))
    return memo, mem
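
frange_cycle_sigmoid is not defined in the snippet; from its use, it returns a per-epoch array of KL weights. A sketch consistent with the cyclical annealing schedule of Fu et al. (2019), whose reference implementation uses this name, is shown below; the keyword defaults are assumptions.

import numpy as np

def frange_cycle_sigmoid(n_epoch, start=0.0, stop=1.0, n_cycle=4, ratio=0.5):
    # Cyclical sigmoid KL annealing: within each cycle, beta follows a sigmoid
    # ramp from `start` to `stop` over the first `ratio` fraction of the cycle,
    # then stays at `stop` for the remainder.
    L = np.ones(n_epoch) * stop
    period = n_epoch / n_cycle
    step = (stop - start) / (period * ratio)  # ramp increment per epoch
    for c in range(n_cycle):
        v, i = start, 0
        while v <= stop and int(i + c * period) < n_epoch:
            L[int(i + c * period)] = 1.0 / (1.0 + np.exp(-(v * 12.0 - 6.0)))
            v += step
            i += 1
    return L
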
Example 3
def train_mlp(args):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    assert args.cae_weight, "No trained CAE weight given"
    cae = CAE().to(device)
    cae.load_state_dict(torch.load(args.cae_weight))
    cae.eval()

    print('Building datasets...')
    train_dataset = PathDataSet(S2D_data_path, cae.encoder)
    val_dataset = PathDataSet(S2D_data_path, cae.encoder, is_val=True)

    print('Creating data loaders...')
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False)

    now = datetime.now()
    output_folder = args.output_folder + '/' + now.strftime(
        '%Y-%m-%d_%H-%M-%S')
    check_and_create_dir(output_folder)

    model = MLP(args.input_size, args.output_size).to(device)
    if args.load_weights:
        print("Load weight from {}".format(args.load_weights))
        model.load_state_dict(torch.load(args.load_weights))

    criterion = nn.MSELoss()
    # optimizer = torch.optim.Adagrad(model.parameters())
    optimizer = AdaBelief(model.parameters(),
                          lr=1e-4,
                          eps=1e-10,
                          betas=(0.9, 0.999),
                          weight_decouple=True,
                          rectify=False)

    for epoch in range(args.max_epoch):
        model.train()

        for i, data in enumerate(tqdm(train_loader)):
            # get data
            input_data = data[0].to(device)  # B, 32
            next_config = data[1].to(device)  # B, 2

            # predict
            predict_config = model(input_data)

            # get loss
            loss = criterion(predict_config, next_config)

            # backpropagation
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()

            neptune.log_metric("batch_loss", loss.item())

        print('\nCalculating validation accuracy...')

        model.eval()
        with torch.no_grad():
            losses = []
            for i, data in enumerate(tqdm(val_loader)):
                # get data
                input_data = data[0].to(device)  # B, 32
                next_config = data[1].to(device)  # B, 2

                # predict
                predict_config = model(input_data)

                # get loss
                loss = criterion(predict_config, next_config)

                losses.append(loss.item())

            val_loss = np.mean(losses)
            neptune.log_metric("val_loss", val_loss)

        print("validation result, epoch {}: {}".format(epoch, val_loss))
        if epoch % 5 == 0:
            torch.save(model.state_dict(),
                       '{}/epoch_{}.tar'.format(output_folder, epoch))
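
check_and_create_dir is a project helper that is not shown; given how it is called, it presumably just ensures the timestamped output folder exists. A minimal sketch under that assumption:

import os

def check_and_create_dir(path):
    # hypothetical helper: create the directory (and any parents) if missing
    os.makedirs(path, exist_ok=True)
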
Example 4
class BaseExperiment:
    """
    Implements a base experiment class for Aspect-Based Sentiment Analysis
    """
    def __init__(self, args):
        self.args = args
        torch.manual_seed(self.args.seed)
        if self.args.device == "cuda":
            torch.cuda.set_device(self.args.gpu)
            torch.cuda.manual_seed(self.args.seed)
        np.random.seed(self.args.seed)
        random.seed(self.args.seed)
        print('> training arguments:')
        for arg in vars(args):
            print('>>> {0}: {1}'.format(arg, getattr(args, arg)))

        tripadvisor_dataset = TripadvisorDatasetReader(
            dataset=args.dataset,
            embed_dim=args.embed_dim,
            max_seq_len=args.max_seq_len)
        if self.args.dev > 0.0:
            random.shuffle(tripadvisor_dataset.train_data.data)
            dev_num = int(
                len(tripadvisor_dataset.train_data.data) * self.args.dev)
            tripadvisor_dataset.dev_data.data = \
                tripadvisor_dataset.train_data.data[:dev_num]
            tripadvisor_dataset.train_data.data = \
                tripadvisor_dataset.train_data.data[dev_num:]

        # print(len(absa_dataset.train_data.data), len(absa_dataset.dev_data.data))

        self.train_data_loader = DataLoader(
            dataset=tripadvisor_dataset.train_data,
            batch_size=args.batch_size,
            shuffle=True)
        if self.args.dev > 0.0:
            self.dev_data_loader = DataLoader(
                dataset=tripadvisor_dataset.dev_data,
                batch_size=len(tripadvisor_dataset.dev_data),
                shuffle=False)
        self.test_data_loader = DataLoader(
            dataset=tripadvisor_dataset.test_data,
            batch_size=len(tripadvisor_dataset.test_data),
            shuffle=False)
        self.target_data_loader = DataLoader(
            dataset=tripadvisor_dataset.test_data,
            batch_size=len(tripadvisor_dataset.test_data),
            shuffle=False)
        self.mdl = args.model_class(
            self.args,
            embedding_matrix=tripadvisor_dataset.embedding_matrix,
            aspect_embedding_matrix=tripadvisor_dataset.aspect_embedding_matrix
        )
        self.reset_parameters()
        self.mdl.encoder.weight.requires_grad = True
        self.mdl.encoder_aspect.weight.requires_grad = True
        self.mdl.to(device)
        self.criterion = nn.CrossEntropyLoss()
        self.learning_history = {}

    def reset_parameters(self):
        n_trainable_params, n_nontrainable_params = 0, 0
        for p in self.mdl.parameters():
            n_params = torch.prod(torch.tensor(p.shape))
            if p.requires_grad:
                n_trainable_params += n_params
                if len(p.shape) > 1:
                    self.args.initializer(p)
            else:
                n_nontrainable_params += n_params
        print('n_trainable_params: {0}, n_nontrainable_params: {1}'.format(
            n_trainable_params, n_nontrainable_params))

    def select_optimizer(self):
        if self.args.optimizer == 'Adam':
            self.optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                               self.mdl.parameters()),
                                        lr=self.args.learning_rate,
                                        weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'AdaBelief':
            self.optimizer = AdaBelief(self.mdl.parameters(),
                                       weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'RMS':
            self.optimizer = optim.RMSprop(filter(lambda p: p.requires_grad,
                                                  self.mdl.parameters()),
                                           lr=self.args.learning_rate)
        elif self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                              self.mdl.parameters()),
                                       lr=self.args.learning_rate,
                                       momentum=0.9)
        elif self.args.optimizer == 'Adagrad':
            self.optimizer = optim.Adagrad(filter(lambda p: p.requires_grad,
                                                  self.mdl.parameters()),
                                           lr=self.args.learning_rate)
        elif self.args.optimizer == 'Adadelta':
            self.optimizer = optim.Adadelta(filter(lambda p: p.requires_grad,
                                                   self.mdl.parameters()),
                                            lr=self.args.learning_rate)

    def load_model(self, path):
        # mdl_best = self.load_model(PATH)
        # best_model_state = mdl_best.state_dict()
        # model_state = self.mdl.state_dict()
        # best_model_state = {k: v for k, v in best_model_state.iteritems() if
        #                     k in model_state and v.size() == model_state[k].size()}
        # model_state.update(best_model_state)
        # self.mdl.load_state_dict(model_state)
        return torch.load(path)

    def train_batch(self, sample_batched):
        self.mdl.zero_grad()
        inputs = [
            sample_batched[col].to(device) for col in self.args.inputs_cols
        ]
        targets = sample_batched['polarity'].to(device)
        outputs = self.mdl(inputs)
        loss = self.criterion(outputs, targets)
        loss.backward()
        clip_gradient(self.mdl.parameters(), 1.0)
        self.optimizer.step()
        return loss.item()

    def evaluation(self, x):
        inputs = [x[col].to(device) for col in self.args.inputs_cols]
        targets = x['polarity'].to(device)
        outputs = self.mdl(inputs)
        outputs = tensor_to_numpy(outputs)
        targets = tensor_to_numpy(targets)
        outputs = np.argmax(outputs, axis=1)
        return outputs, targets

    def metric(self, targets, outputs, save_path=None):
        dist = dict(Counter(outputs))
        acc = accuracy_score(targets, outputs)
        macro_recall = recall_score(targets,
                                    outputs,
                                    labels=[0, 1, 2],
                                    average='macro')
        macro_precision = precision_score(targets,
                                          outputs,
                                          labels=[0, 1, 2],
                                          average='macro')
        macro_f1 = f1_score(targets,
                            outputs,
                            labels=[0, 1, 2],
                            average='macro')
        weighted_recall = recall_score(targets,
                                       outputs,
                                       labels=[0, 1, 2],
                                       average='weighted')
        weighted_precision = precision_score(targets,
                                             outputs,
                                             labels=[0, 1, 2],
                                             average='weighted')
        weighted_f1 = f1_score(targets,
                               outputs,
                               labels=[0, 1, 2],
                               average='weighted')
        micro_recall = recall_score(targets,
                                    outputs,
                                    labels=[0, 1, 2],
                                    average='micro')
        micro_precision = precision_score(targets,
                                          outputs,
                                          labels=[0, 1, 2],
                                          average='micro')
        micro_f1 = f1_score(targets,
                            outputs,
                            labels=[0, 1, 2],
                            average='micro')
        recall = recall_score(targets, outputs, labels=[0, 1, 2], average=None)
        precision = precision_score(targets,
                                    outputs,
                                    labels=[0, 1, 2],
                                    average=None)
        f1 = f1_score(targets, outputs, labels=[0, 1, 2], average=None)
        result = {
            'acc': acc,
            'recall': recall,
            'precision': precision,
            'f1': f1,
            'macro_recall': macro_recall,
            'macro_precision': macro_precision,
            'macro_f1': macro_f1,
            'micro_recall': micro_recall,
            'micro_precision': micro_precision,
            'micro_f1': micro_f1,
            'weighted_recall': weighted_recall,
            'weighted_precision': weighted_precision,
            'weighted_f1': weighted_f1
        }
        # print("Output Distribution={}, Acc: {}, Macro-F1: {}".format(dist, acc, macro_f1))
        if save_path is not None:
            f_to = open(save_path, 'w')
            f_to.write("lr: {}\n".format(self.args.learning_rate))
            f_to.write("batch_size: {}\n".format(self.args.batch_size))
            f_to.write("opt: {}\n".format(self.args.optimizer))
            f_to.write("max_sentence_len: {}\n".format(self.args.max_seq_len))
            f_to.write(
                "end params -----------------------------------------------------------------\n"
            )
            for key in result.keys():
                f_to.write("{}: {}\n".format(key, result[key]))
            f_to.write(
                "end metrics -----------------------------------------------------------------\n"
            )
            for i in range(len(outputs)):
                f_to.write("{}: {},{}\n".format(i, outputs[i], targets[i]))
            f_to.write(
                "end ans -----------------------------------------------------------------\n"
            )
            f_to.close()
        return result

    def train(self):
        best_acc = 0.0
        best_result = None
        global_step = 0
        self.select_optimizer()
        losses_train = []
        accuracy_train, accuracy_validation = [], []
        for epoch in range(self.args.num_epoch):
            losses = []
            self.mdl.train()
            t0 = time.time()
            outputs_train, targets_train = None, None
            for i_batch, sample_batched in enumerate(self.train_data_loader):
                global_step += 1
                loss = self.train_batch(sample_batched)
                losses.append(loss)
                output_train, target_train = self.evaluation(sample_batched)
                if outputs_train is None:
                    outputs_train = output_train
                else:
                    outputs_train = np.concatenate(
                        (outputs_train, output_train))

                if targets_train is None:
                    targets_train = target_train
                else:
                    targets_train = np.concatenate(
                        (targets_train, target_train))
            results_train = self.metric(targets=targets_train,
                                        outputs=outputs_train)
            t1 = time.time()
            self.mdl.eval()
            if self.args.dev > 0.0:
                outputs, targets = None, None
                with torch.no_grad():
                    for d_batch, d_sample_batched in enumerate(
                            self.dev_data_loader):
                        output, target = self.evaluation(d_sample_batched)
                        if outputs is None:
                            outputs = output
                        else:
                            outputs = np.concatenate((outputs, output))

                        if targets is None:
                            targets = target
                        else:
                            targets = np.concatenate((targets, target))
                    result = self.metric(targets=targets, outputs=outputs)
                    if result['acc'] > best_acc:
                        best_acc = result['acc']
                        path = save_path + 'models/{}_{}_{}_{}_{}_{}_{}_{}_{}_{}.model'. \
                            format(self.args.model_name,
                                   self.args.dataset,
                                   self.args.optimizer,
                                   self.args.learning_rate,
                                   self.args.weight_decay,
                                   self.args.dropout,
                                   self.args.batch_normalizations,
                                   self.args.softmax,
                                   self.args.batch_size,
                                   self.args.dev)
                        torch.save(self.mdl.state_dict(), path)
                        best_result = result
            else:
                outputs, targets = None, None
                with torch.no_grad():
                    for t_batch, t_sample_batched in enumerate(
                            self.test_data_loader):
                        output, target = self.evaluation(t_sample_batched)
                        if outputs is None:
                            outputs = output
                        else:
                            outputs = np.concatenate((outputs, output))
                        if targets is None:
                            targets = target
                        else:
                            targets = np.concatenate((targets, target))
                    result = self.metric(targets=targets, outputs=outputs)
                    if result['acc'] > best_acc:
                        best_acc = result['acc']
                        path = save_path + 'models/{}_{}_{}_{}_{}_{}_{}_{}_{}_{}.model'. \
                            format(self.args.model_name,
                                   self.args.dataset,
                                   self.args.optimizer,
                                   self.args.learning_rate,
                                   self.args.weight_decay,
                                   self.args.dropout,
                                   self.args.batch_normalizations,
                                   self.args.softmax,
                                   self.args.batch_size,
                                   self.args.dev)
                        torch.save(self.mdl.state_dict(), path)
                        best_result = result
            print('\033[1;31m[Epoch {:>4}]\033[0m  '
                  '\033[1;31mTrain loss={:.5f}\033[0m  '
                  '\033[1;32mTrain accuracy={:.2f}%\033[0m  '
                  '\033[1;33mValidation accuracy={:.2f}%\033[0m  '
                  'Time cost={:.2f}s'.format(epoch + 1, np.mean(losses),
                                             results_train['acc'] * 100,
                                             result['acc'] * 100, t1 - t0))
            losses_train.append(np.mean(losses))
            accuracy_train.append(results_train['acc'])
            accuracy_validation.append(result['acc'])
        self.learning_history['Loss'] = np.array(losses_train).tolist()
        self.learning_history['Training accuracy'] = np.array(
            accuracy_train).tolist()
        self.learning_history['Validation accuracy'] = np.array(
            accuracy_validation).tolist()
        self.learning_history['Best Validation accuracy'] = best_acc

    def test(self):
        path = save_path + 'models/{}_{}_{}_{}_{}_{}_{}_{}_{}_{}.model'. \
            format(self.args.model_name,
                   self.args.dataset,
                   self.args.optimizer,
                   self.args.learning_rate,
                   self.args.weight_decay,
                   self.args.dropout,
                   self.args.batch_normalizations,
                   self.args.softmax,
                   self.args.batch_size,
                   self.args.dev)
        self.mdl.load_state_dict(self.load_model(path))
        self.mdl.eval()
        outputs, targets = None, None
        with torch.no_grad():
            for t_batch, t_sample_batched in enumerate(self.test_data_loader):
                output, target = self.evaluation(t_sample_batched)
                if outputs is None:
                    outputs = output
                else:
                    outputs = np.concatenate((outputs, output))

                if targets is None:
                    targets = target
                else:
                    targets = np.concatenate((targets, target))
        result = self.metric(targets=targets, outputs=outputs)
        print('\033[1;32mTest accuracy:{:.2f}%, macro_f1:{:.5f}\033[0m'.format(
            result['acc'] * 100, result['macro_f1']))
        self.learning_history['Test accuracy'] = result['acc']
        # Plot confusion matrix
        class_names = ['negative', 'neutral', 'positive']
        cnf_matrix = confusion_matrix(targets, outputs)
        plot_confusion_matrix(cnf_matrix,
                              classes=class_names,
                              title='Confusion matrix',
                              normalize=False)
        plt.savefig('./result/figures/'
                    '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}.png'.format(
                        self.args.model_name, self.args.dataset,
                        self.args.optimizer, self.args.learning_rate,
                        self.args.weight_decay, self.args.dropout,
                        self.args.batch_normalizations, self.args.softmax,
                        self.args.batch_size, self.args.dev))

    def save_learning_history(self):
        data = json.dumps(self.learning_history, indent=2)
        with open(
                './result/learning history/'
                '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}.json'.format(
                    self.args.model_name, self.args.dataset,
                    self.args.optimizer, self.args.learning_rate,
                    self.args.weight_decay, self.args.dropout,
                    self.args.batch_normalizations, self.args.softmax,
                    self.args.batch_size, self.args.dev), 'w') as f:
            f.write(data)

    def transfer_learning(self):
        model_path = save_path + 'models/{}_{}_{}_{}_{}_{}_{}_{}_{}.model'.format(
            self.args.model_name, self.args.pre_trained_model,
            self.args.optimizer, self.args.learning_rate,
            self.args.max_seq_len, self.args.dropout, self.args.softmax,
            self.args.batch_size, self.args.dev)
        self.mdl.load_state_dict(self.load_model(model_path))
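
clip_gradient in train_batch is another helper that is not shown. A common implementation with this name clamps each gradient elementwise, as sketched below; if the project intended norm clipping instead, torch.nn.utils.clip_grad_norm_(self.mdl.parameters(), 1.0) is the drop-in equivalent. Either way, this is an assumption about code outside the excerpt.

def clip_gradient(parameters, clip_value):
    # hypothetical helper: clamp every gradient element to [-clip_value, clip_value]
    for p in parameters:
        if p.grad is not None:
            p.grad.data.clamp_(-clip_value, clip_value)
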
Example 5
class Agent():
  def __init__(self, args, env):
    self.action_space = env.action_space()
    self.atoms = args.atoms
    self.Vmin = args.V_min
    self.Vmax = args.V_max
    self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(device=args.device)  # Support (range) of z
    self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
    self.batch_size = args.batch_size
    self.n = args.multi_step
    self.discount = args.discount
    self.norm_clip = args.norm_clip

    self.online_net = DQN(args, self.action_space).to(device=args.device)
    if args.model:  # Load pretrained model if provided
      if os.path.isfile(args.model):
        state_dict = torch.load(args.model, map_location='cpu')  # Always load tensors onto CPU by default, will shift to GPU if necessary
        if 'conv1.weight' in state_dict.keys():
          for old_key, new_key in (('conv1.weight', 'convs.0.weight'), ('conv1.bias', 'convs.0.bias'), ('conv2.weight', 'convs.2.weight'), ('conv2.bias', 'convs.2.bias'), ('conv3.weight', 'convs.4.weight'), ('conv3.bias', 'convs.4.bias')):
            state_dict[new_key] = state_dict[old_key]  # Re-map state dict for old pretrained models
            del state_dict[old_key]  # Delete old keys for strict load_state_dict
        self.online_net.load_state_dict(state_dict)
        print("Loading pretrained model: " + args.model)
      else:  # Raise error if incorrect model path provided
        raise FileNotFoundError(args.model)

    self.online_net.train()

    self.target_net = DQN(args, self.action_space).to(device=args.device)
    self.update_target_net()
    self.target_net.train()
    for param in self.target_net.parameters():
      param.requires_grad = False

    self.optimiser = AdaBelief(self.online_net.parameters(), lr=args.learning_rate,
                               eps=args.adam_eps, rectify=True)
    # previously: optim.Adam(self.online_net.parameters(), lr=args.learning_rate, eps=args.adam_eps)

  # Resets noisy weights in all linear layers (of online net only)
  def reset_noise(self):
    self.online_net.reset_noise()

  # Acts based on single state (no batch)
  def act(self, state):
    with torch.no_grad():
      return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).argmax(1).item()

  # Acts with an ε-greedy policy (used for evaluation only)
  def act_e_greedy(self, state, epsilon=0.001):  # High ε can reduce evaluation scores drastically
    return np.random.randint(0, self.action_space) if np.random.random() < epsilon else self.act(state)

  def learn(self, mem):
    # Sample transitions
    idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size)

    # Calculate current state probabilities (online network noise already sampled)
    log_ps = self.online_net(states, log=True)  # Log probabilities log p(s_t, ·; θonline)
    log_ps_a = log_ps[range(self.batch_size), actions]  # log p(s_t, a_t; θonline)

    with torch.no_grad():
      # Calculate nth next state probabilities
      pns = self.online_net(next_states)  # Probabilities p(s_t+n, ·; θonline)
      dns = self.support.expand_as(pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
      argmax_indices_ns = dns.sum(2).argmax(1)  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
      self.target_net.reset_noise()  # Sample new target net noise
      pns = self.target_net(next_states)  # Probabilities p(s_t+n, ·; θtarget)
      pns_a = pns[range(self.batch_size), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

      # Compute Tz (Bellman operator T applied to z)
      Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
      Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)  # Clamp between supported values
      # Compute L2 projection of Tz onto fixed support z
      b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
      l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
      # Fix disappearing probability mass when l = b = u (b is int)
      l[(u > 0) * (l == u)] -= 1
      u[(l < (self.atoms - 1)) * (l == u)] += 1

      # Distribute probability of Tz
      m = states.new_zeros(self.batch_size, self.atoms)
      offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand(self.batch_size, self.atoms).to(actions)
      m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
      m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

    loss = -torch.sum(m * log_ps_a, 1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
    self.online_net.zero_grad()
    (weights * loss).mean().backward()  # Backpropagate importance-weighted minibatch loss
    clip_grad_norm_(self.online_net.parameters(), self.norm_clip)  # Clip gradients by L2 norm
    self.optimiser.step()

    mem.update_priorities(idxs, loss.detach().cpu().numpy())  # Update priorities of sampled transitions

  def update_target_net(self):
    self.target_net.load_state_dict(self.online_net.state_dict())

  # Save model parameters on current device (don't move model between devices)
  def save(self, path, name='model.pth'):
    torch.save(self.online_net.state_dict(), os.path.join(path, name))

  # Evaluates Q-value based on single state (no batch)
  def evaluate_q(self, state):
    with torch.no_grad():
      return (self.online_net(state.unsqueeze(0)) * self.support).sum(2).max(1)[0].item()

  def train(self):
    self.online_net.train()

  def eval(self):
    self.online_net.eval()
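
The heart of learn() is the L2 projection of the shifted support Tz back onto the fixed atoms. The self-contained toy run below (one sample, five atoms, made-up numbers) reproduces that arithmetic and shows how probability mass is split between neighbouring atoms; all values are illustrative only.

import torch

Vmin, Vmax, atoms = -1.0, 1.0, 5
support = torch.linspace(Vmin, Vmax, atoms)          # z = [-1.0, -0.5, 0.0, 0.5, 1.0]
delta_z = (Vmax - Vmin) / (atoms - 1)

pns_a = torch.tensor([[0.1, 0.2, 0.4, 0.2, 0.1]])    # target distribution for one sample
Tz = (0.3 + 0.99 * support).clamp(Vmin, Vmax)        # reward 0.3, discount 0.99, non-terminal
b = (Tz - Vmin) / delta_z                            # fractional atom positions
l, u = b.floor().long(), b.ceil().long()
l[(u > 0) & (l == u)] -= 1                           # keep mass when b lands exactly on an atom
u[(l < atoms - 1) & (l == u)] += 1

m = torch.zeros(atoms)
m.index_add_(0, l, (pns_a * (u.float() - b)).view(-1))   # lower-neighbour share
m.index_add_(0, u, (pns_a * (b - l.float())).view(-1))   # upper-neighbour share
print(m, m.sum())                                    # projected distribution; mass still sums to 1
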
Example 6
def main():
    """Model training."""
    train_speakers, valid_speakers = get_valid_speakers()

    # define transforms for train & validation samples
    train_transform = Compose([Resize(760, 80), ToTensor()])

    # define datasets & loaders
    train_dataset = TrainDataset('train',
                                 train_speakers,
                                 transform=train_transform)
    valid_dataset = TrainDataset('train',
                                 valid_speakers,
                                 transform=train_transform)

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=256, shuffle=False)

    device = get_device()
    print(f'Selected device: {device}')

    model = torch.hub.load('huawei-noah/ghostnet',
                           'ghostnet_1x',
                           pretrained=True)
    model.classifier = nn.Linear(in_features=1280, out_features=1, bias=True)

    net = model
    net.to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = AdaBelief(net.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode='min',
                                                     factor=0.2,
                                                     patience=3,
                                                     eps=1e-4,
                                                     verbose=True)

    # prepare valid target
    yvalid = get_valid_targets(valid_dataset)

    # training loop
    loss_log = {'train': [], 'valid': []}  # accumulate per-epoch losses across the run
    for epoch in range(10):
        train_loss = []

        net.train()
        for x, y in tqdm(train_loader):
            x, y = mixup(x, y, alpha=0.2)
            x, y = x.to(device), y.to(device, dtype=torch.float32)
            optimizer.zero_grad()
            outputs = net(x)

            loss = criterion(outputs, y.unsqueeze(1))
            loss.backward()
            optimizer.step()

            # save loss
            train_loss.append(loss.item())

        # evaluate
        net.eval()
        valid_pred = torch.Tensor([]).to(device)

        for x, y in valid_loader:
            with torch.no_grad():
                x, y = x.to(device), y.to(device, dtype=torch.float32)
                ypred = net(x)
                valid_pred = torch.cat([valid_pred, ypred], 0)

        valid_pred = sigmoid(valid_pred.cpu().numpy())
        val_loss = log_loss(yvalid, valid_pred, eps=1e-7)
        val_acc = (yvalid == (valid_pred > 0.5).astype(int).flatten()).mean()
        tqdm.write(
            f'Epoch {epoch} train_loss={np.mean(train_loss):.4f}; val_loss={val_loss:.4f}; val_acc={val_acc:.4f}'
        )

        loss_log['train'].append(np.mean(train_loss))
        loss_log['valid'].append(val_loss)
        scheduler.step(loss_log['valid'][-1])

    torch.save(net.state_dict(), 'ghostnet_model.pt')
    print('Training is complete.')
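
mixup is imported from elsewhere in the project; a standard input-mixup helper (Zhang et al., 2018) compatible with the call above might look like the following sketch. Casting the binary labels to float so they can be blended is an assumption about the missing code.

import numpy as np
import torch

def mixup(x, y, alpha=0.2):
    # hypothetical helper: blend each sample (and its label) with a randomly
    # permuted partner; the mixing weight is drawn from Beta(alpha, alpha)
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0))
    mixed_x = lam * x + (1.0 - lam) * x[index]
    mixed_y = lam * y.float() + (1.0 - lam) * y[index].float()
    return mixed_x, mixed_y
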