Example #1
 def step(self, closure=None):
     """Gradient clipping aware step()."""
     if self.gclip is not None and self.gclip > 0:
         clip_grad_norm_(self.params, self.gclip)
     self.optim.step(closure)
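The snippet above relies on surrounding attributes (self.params, self.optim, self.gclip) that are not shown. A minimal self-contained sketch of such a clipping-aware wrapper, with assumed names and a stand-in SGD optimizer, might look like this:

import torch
from torch.nn.utils import clip_grad_norm_

class ClippingOptimizer:
    """Sketch of an optimizer wrapper that clips the global grad norm before stepping."""

    def __init__(self, params, optim, gclip=None):
        self.params = list(params)  # materialize so the list can be reused every step
        self.optim = optim          # the wrapped torch.optim optimizer
        self.gclip = gclip          # max gradient norm; None or 0 disables clipping

    def zero_grad(self):
        self.optim.zero_grad()

    def step(self, closure=None):
        if self.gclip is not None and self.gclip > 0:
            clip_grad_norm_(self.params, self.gclip)
        self.optim.step(closure)

# usage sketch
model = torch.nn.Linear(4, 2)
wrapped = ClippingOptimizer(model.parameters(),
                            torch.optim.SGD(model.parameters(), lr=0.1),
                            gclip=1.0)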
Example #2
def run(args):
    # Check cuda device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Data
    if hps.bucket:
        dataset = LJSpeech_Dataset(
            meta_file=hps.meta_path,
            wav_dir=hps.wav_dir,
            batch_size=hps.batch_size,
            do_bucket=True,
            bucket_size=20)
        loader = DataLoader(
            dataset, 
            batch_size=1,
            shuffle=True,
            num_workers=4)
    else:
        dataset = LJSpeech_Dataset(meta_file=hps.meta_path, wav_dir=hps.wav_dir)
        loader = DataLoader(
            dataset,
            batch_size=hps.batch_size,
            shuffle=True,
            num_workers=4,
            drop_last=True,
            collate_fn=collate_fn)

    # Network
    model = Tacotron()
    criterion = nn.L1Loss()
    if args.cuda:
        model = nn.DataParallel(model.to(device))
        criterion = criterion.to(device)
    # The learning rate scheduling mechanism in "Attention is all you need" 
    lr_lambda = lambda step: hps.warmup_step ** 0.5 * min((step+1) * (hps.warmup_step ** -1.5), (step+1) ** -0.5)
    optimizer = optim.Adam(model.parameters(), lr=hps.lr)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        
    step = 1
    epoch = 1
    # Load model
    if args.ckpt:
        ckpt = load(args.ckpt)
        step = ckpt['step']
        epoch = ckpt['epoch']
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, 
            lr_lambda, 
            last_epoch=step)

    if args.eval:
        # Evaluation
        model.eval()
        with torch.no_grad():
            # Preprocessing eval texts
            print('Start generating evaluation speeches...')
            n_eval = len(hps.eval_texts)
            for i in range(n_eval):
                sys.stdout.write('\rProgress: {}/{}'.format(i+1, n_eval))
                sys.stdout.flush()
                text = hps.eval_texts[i]
                text = text_normalize(text)
                txt_id = sent2idx(text) + [hps.char_set.find('E')]
                GO_frame = torch.zeros(1, 1, hps.n_mels)

                # Shape: (1, seq_length)
                txt = torch.LongTensor(txt_id).unsqueeze(0)
                if args.cuda:
                    GO_frame = GO_frame.cuda()
                    txt = txt.cuda()
                _batch = model(text=txt, frames=GO_frame)
                mel = _batch['mel'][0]
                mag = _batch['mag'][0]
                attn = _batch['attn'][0]
                if args.cuda:
                    mel = mel.cpu()
                    mag = mag.cpu()
                    attn = attn.cpu()
                mel = mel.numpy()
                mag = mag.numpy()
                attn = attn.numpy()

                wav = mag2wav(mag)
                save_alignment(attn, step, 'eval/plots/attn_{}.png'.format(text))
                save_spectrogram(mag, 'eval/plots/spectrogram_[{}].png'.format(text))
                save_wav(wav, 'eval/results/wav_{}.wav'.format(text))
            sys.stdout.write('\n')

    if args.train:
        before_load = time.time()
        # Start training
        model.train()
        while True:
            for batch in loader:
                # torch.LongTensor, (batch_size, seq_length)
                txt = batch['text']
                # torch.Tensor, (batch_size, max_time, hps.n_mels)
                mel = batch['mel']
                # torch.Tensor, (batch_size, max_time, hps.n_fft)
                mag = batch['mag']
                if hps.bucket:
                    # If bucketing, the shape will be (1, batch_size, ...)
                    txt = txt.squeeze(0)
                    mel = mel.squeeze(0)
                    mag = mag.squeeze(0)
                # GO frame
                GO_frame = torch.zeros(mel[:, :1, :].size())
                if args.cuda:
                    txt = txt.to(device)
                    mel = mel.to(device)
                    mag = mag.to(device)
                    GO_frame = GO_frame.to(device)

                # Model prediction
                decoder_input = torch.cat([GO_frame, mel[:, hps.reduction_factor::hps.reduction_factor, :]], dim=1)

                load_time = time.time() - before_load
                before_step = time.time()

                _batch = model(text=txt, frames=decoder_input)
                _mel = _batch['mel']
                _mag = _batch['mag']
                _attn = _batch['attn']

                # Optimization
                optimizer.zero_grad()
                loss_mel = criterion(_mel, mel)
                loss_mag = criterion(_mag, mag)
                loss = loss_mel + loss_mag
                loss.backward()
                # Gradient clipping
                total_norm = clip_grad_norm_(model.parameters(), max_norm=hps.clip_norm)
                # Apply gradient
                optimizer.step()
                # Adjust learning rate
                scheduler.step()
                process_time = time.time() - before_step 
                if step % hps.log_every_step == 0:
                    lr_curr = optimizer.param_groups[0]['lr']
                    log = '[{}-{}] loss: {:.3f}, grad: {:.3f}, lr: {:.3e}, time: {:.2f} + {:.2f} sec'.format(epoch, step, loss.item(), total_norm, lr_curr, load_time, process_time)
                    print(log)
                if step % hps.save_model_every_step == 0:
                    save(filepath='tmp/ckpt/ckpt_{}.pth.tar'.format(step),
                         model=model.state_dict(),
                         optimizer=optimizer.state_dict(),
                         step=step, 
                         epoch=epoch)

                if step % hps.save_result_every_step == 0:
                    sample_idx = random.randint(0, hps.batch_size-1)
                    attn_sample = _attn[sample_idx].detach().cpu().numpy()
                    mag_sample = _mag[sample_idx].detach().cpu().numpy()
                    wav_sample = mag2wav(mag_sample)
                    # Save results
                    save_alignment(attn_sample, step, 'tmp/plots/attn_{}.png'.format(step))
                    save_spectrogram(mag_sample, 'tmp/plots/spectrogram_{}.png'.format(step))
                    save_wav(wav_sample, 'tmp/results/wav_{}.wav'.format(step))
                before_load = time.time()
                step += 1
            epoch += 1
Example #3
    def run(self, mt_loader, epoch=None, is_training=True):

        # always transductive
        self._algorithm._model.train()

        # loaders and iterators
        mt_iterator = tqdm(enumerate(mt_loader, start=1),
                        leave=False, file=src.logger.stdout, position=0)
        
        # metrics aggregation
        aggregate = defaultdict(list)
        
        # constants
        n_way = mt_loader.n_way
        n_shot = mt_loader.n_shot
        n_query = mt_loader.n_query
        mt_batch_sz = mt_loader.batch_size        
        randomize_query = mt_loader.randomize_query
        print(f"n_way: {n_way}, n_shot: {n_shot}, n_query: {n_query} mt_batch_sz: {mt_batch_sz} randomize_query: {randomize_query}")

        # iterating over tasks
        for i, mt_batch in mt_iterator:

            # set zero grad
            if is_training:
                self._optimizer.zero_grad()
                # global iterator count
                self._global_iteration += 1
            
            analysis = (i % self._log_interval == 0)

            '''
            # randperm
            if randomize_query and is_training:
                rp = np.random.permutation(mgr_n_query * n_way)[:n_query * n_way]
            else:
                rp = None 

            # meta-learning data creation
            mt_batch_x, mt_batch_y = mt_batch
            mt_batch_y = mt_batch_y - self._label_offset
            original_shape = mt_batch_x.shape
            assert len(original_shape) == 5
            # (batch_sz*n_way, n_shot+n_query, channels , height , width)
            mt_batch_x = mt_batch_x.reshape(mt_batch_sz, n_way, *original_shape[-4:])
            # (batch_sz, n_way, n_shot+n_query, channels , height , width)
            shots_x = mt_batch_x[:, :, :n_shot, :, :, :]
            # (batch_sz, n_way, n_shot, channels , height , width)
            if rp is None:
                query_x = mt_batch_x[:, :, n_shot:, :, :, :]
            else:
                query_x = []
                for c in range(n_way):
                    indices = rp[(rp>=(c*2*n_query)) & (rp<((c+1)*2*n_query))] - (c*2*n_query)
                    query_x.append(mt_batch_x[:, c, n_shot + indices, :, :, :])
                query_x = torch.cat(query_x, dim=1)
            # (batch_sz, n_way, n_query, channels , height , width)
            shots_x = shots_x.reshape(mt_batch_sz, -1, *original_shape[-3:])
            # (batch_sz, n_way*n_shot, channels , height , width)
            query_x = query_x.reshape(mt_batch_sz, -1, *original_shape[-3:])
            # (batch_sz, n_way*n_query, channels , height , width)
            shots_y, query_y = get_labels(mt_batch_y, n_way=n_way, 
                n_shot=n_shot, n_query=n_query, batch_sz=mt_batch_sz, rp=rp)
            '''

            shots_x, shots_y, query_x, query_y = mt_batch

            assert shots_x.shape[0:2] == (mt_batch_sz, n_way*n_shot)
            assert query_x.shape[0:2] == (mt_batch_sz, n_way*n_query)
            assert shots_y.shape == (mt_batch_sz, n_way*n_shot)
            assert query_y.shape == (mt_batch_sz, n_way*n_query)

            # to cuda
            shots_x = shots_x.cuda()
            query_x = query_x.cuda()
            shots_y = shots_y.cuda()
            query_y = query_y.cuda()
            
            for task_id in range(mt_batch_sz):
                # compute outer gradients and populate model grad with it
                # so that we can directly call optimizer.step()
                measurements_trajectory = self._algorithm.inner_loop_adapt(
                    query=query_x[task_id:task_id+1], 
                    query_labels=query_y[task_id:task_id+1], 
                    support=shots_x[task_id:task_id+1],  
                    support_labels=shots_y[task_id:task_id+1],
                    n_way=n_way, n_shot=n_shot, n_query=n_query,
                    num_updates_inner=self._num_updates_inner_train\
                         if is_training else self._num_updates_inner_val)

                # metrics accumulation
                for k in measurements_trajectory:
                    aggregate[k].append(measurements_trajectory[k][-1])

            # optimizer step
            if is_training:
                for param in self._algorithm._model.parameters():
                    param.grad /= mt_batch_sz
                if self._grad_clip > 0.:
                    clip_grad_norm_(self._algorithm._model.parameters(), 
                        max_norm=self._grad_clip, norm_type='inf')
                self._optimizer.step()

            # logging
            if analysis and is_training:
                metrics = {}
                for name, values in aggregate.items():
                    metrics[name] = np.mean(values)
                self.log_output(epoch, i, metrics)
                aggregate = defaultdict(list)    

        # save model and log tboard for eval
        if is_training and self._save_folder is not None:
            save_name = "chkpt_{0:03d}.pt".format(epoch)
            save_path = os.path.join(self._save_folder, save_name)
            with open(save_path, 'wb') as f:
                torch.save({'model': self._algorithm._model.state_dict(),
                           'optimizer': self._optimizer}, f)

        results = {
            'train_loss_trajectory': {
                'loss': np.mean(aggregate['loss']), 
                'accu': np.mean(aggregate['accu']),
            },
            'test_loss_after': {
                'loss': np.mean(aggregate['mt_outer_loss']),
                'accu': np.mean(aggregate['mt_outer_accu']),
            }
        }
        mean, i95 = (np.mean(aggregate['mt_outer_accu']), 
            1.96 * np.std(aggregate['mt_outer_accu']) / np.sqrt(len(aggregate['mt_outer_accu'])))
        results['val_task_acc'] = "{:.2f} ± {:.2f} %".format(mean, i95) 
    
        return results
Example #4
 def clip_grads(self, params):
     clip_grad.clip_grad_norm_(filter(lambda p: p.requires_grad, params),
                               **self.grad_clip)
Example #5
 def clip_grads(self, params):
     params = list(
         filter(lambda p: p.requires_grad and p.grad is not None, params))
     if len(params) > 0:
         return clip_grad.clip_grad_norm_(params, **self.grad_clip)
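Because self.grad_clip is unpacked as keyword arguments, it is presumably a dict of clip_grad_norm_ kwargs (the mmcv-style convention). A minimal sketch of that calling convention, with hypothetical config values and a stand-in model:

import torch
from torch.nn.utils import clip_grad

model = torch.nn.Linear(8, 2)                # stand-in model (assumption)
grad_clip = dict(max_norm=35, norm_type=2)   # hypothetical kwargs passed through **self.grad_clip

model(torch.randn(4, 8)).sum().backward()
params = [p for p in model.parameters() if p.requires_grad and p.grad is not None]
if len(params) > 0:
    # clip_grad_norm_ returns the total norm measured before clipping
    total_norm = clip_grad.clip_grad_norm_(params, **grad_clip)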
Example #6
                    Normal(
                        means, torch.exp(logstd.unsqueeze(0).expand(batch_size, -1))
                    ),
                    1,
                )
                log_probs = dist.log_prob(actions.squeeze())

                values = value_fn(states)
                clipped_values = old_values + torch.clamp(
                    values - old_values, -args.clip_range, args.clip_range
                )
                l_vf1 = (values - targets).pow(2)
                l_vf2 = (clipped_values - targets).pow(2)
                value_fn_loss = 0.5 * torch.max(l_vf1, l_vf2).mean()
                value_fn_loss.backward()
                clip_grad_norm_(value_fn.parameters(), args.max_grad_norm)
                value_fn_opt.step()
                value_fn_opt.zero_grad()

                k = logstd.shape[0]
                entropy = (k / 2) * (1 + math.log(2 * math.pi)) + 0.5 * torch.log(
                    torch.exp(logstd).pow(2).prod()
                )
                advs = ((advs - advs.mean()) / (advs.std() + 1e-8)).squeeze()

                prob_ratio = torch.exp(log_probs - old_log_probs)
                clipped_ratio = torch.clamp(
                    prob_ratio, 1 - args.clip_range, 1 + args.clip_range
                )
                l_cpi = prob_ratio * advs
                l_clip = clipped_ratio * advs
Example #7
def main():
    # Hyper Parameters

    opt = opts.parse_opt()

    device_id = opt.gpuid
    device_count = len(str(device_id).split(","))
    #assert device_count == 1 or device_count == 2
    print("use GPU:", device_id, "GPUs_count", device_count, flush=True)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(device_id)
    device_id = 0
    torch.cuda.set_device(0)

    # Load Vocabulary Wrapper
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)
    print("Vocab loaded, size:{}".format(len(vocab)))

    # Load data loaders
    train_loader, val_loader = data.get_loaders(opt.data_name, vocab,
                                                opt.batch_size, opt.workers,
                                                opt)
    print('Dataloader done.')

    # Construct the model
    model = SCAN(opt)
    model.cuda()
    model = nn.DataParallel(model)
    print('Model Defined.')

    # Loss and Optimizer
    criterion = ContrastiveLoss(opt=opt,
                                margin=opt.margin,
                                max_violation=opt.max_violation)
    mse_criterion = nn.MSELoss(reduction="batchmean")
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate)

    # optionally resume from a checkpoint
    if not os.path.exists(opt.model_name):
        os.makedirs(opt.model_name)
    start_epoch = 0
    best_rsum = 0

    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # evalrank(model.module, val_loader, opt)

    print(opt, flush=True)

    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        message = "epoch: %d, model name: %s\n" % (epoch, opt.model_name)
        log_file = os.path.join(opt.logger_name, "performance.log")
        logging_func(log_file, message)
        print("model name: ", opt.model_name, flush=True)
        adjust_learning_rate(opt, optimizer, epoch)
        run_time = 0
        for i, (images, captions, lengths, masks, ids,
                _) in enumerate(train_loader):
            start_time = time.time()
            model.train()

            optimizer.zero_grad()

            if device_count != 1:
                images = images.repeat(device_count, 1, 1)

            score = model(images, captions, lengths, masks, ids)
            loss = criterion(score)

            loss.backward()
            if opt.grad_clip > 0:
                clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()
            run_time += time.time() - start_time
            # validate at every val_step
            if i % 100 == 0:
                log = "epoch: %d; batch: %d/%d; loss: %.4f; time: %.4f" % (
                    epoch, i, len(train_loader), loss.data.item(),
                    run_time / 100)
                print(log, flush=True)
                run_time = 0
            # if (i + 1) % opt.val_step == 0:
            #     evalrank(model.module, val_loader, opt)

        print("-------- performance at epoch: %d --------" % (epoch))
        # evaluate on validation set
        rsum = evalrank(model.module, val_loader, opt)
        #rsum = -100
        filename = 'model_' + str(epoch) + '.pth.tar'
        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'best_rsum': best_rsum,
                'opt': opt,
            },
            is_best,
            filename=filename,
            prefix=opt.model_name + '/')
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'best_rsum': best_rsum,
                'opt': opt,
            },
            False,
            filename=filename,
            prefix=opt.model_name + '/')
Example #8
    def _train_step(self) -> None:
        """Run a training step over the training data."""
        self.model.train()
        metrics_with_states: List[Tuple] = [
            (metric, {}) for metric in self.training_metrics
        ]
        self._last_train_log_step = 0

        log_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""
        log_prefix += 'Training'

        with torch.enable_grad():
            for i in range(self.iter_per_step):
                # Zero the gradients and clear the accumulated loss
                self.optimizer.zero_grad()
                accumulated_loss = 0.0
                for _ in range(self.batches_per_iter):
                    # Get next batch
                    try:
                        batch = next(self._train_iterator)
                    except StopIteration:
                        self._create_train_iterator()
                        batch = next(self._train_iterator)
                    batch = self._batch_to_device(batch)

                    # Compute loss
                    _, _, loss = self._compute_batch(batch,
                                                     metrics_with_states)
                    accumulated_loss += loss.item() / self.batches_per_iter
                    loss.backward()

                # Log loss
                global_step = (self.iter_per_step * self._step) + i

                # Clip gradients if necessary
                if self.max_grad_norm:
                    clip_grad_norm_(self.model.parameters(),
                                    self.max_grad_norm)
                if self.max_grad_abs_val:
                    clip_grad_value_(self.model.parameters(),
                                     self.max_grad_abs_val)

                log(f'{log_prefix}/Loss', accumulated_loss, global_step)
                log(f'{log_prefix}/Gradient_Norm', self.model.gradient_norm,
                    global_step)
                log(f'{log_prefix}/Parameter_Norm', self.model.parameter_norm,
                    global_step)

                # Optimize
                self.optimizer.step()

                # Update iter scheduler
                if self.iter_scheduler is not None:
                    learning_rate = self.iter_scheduler.get_lr()[
                        0]  # type: ignore
                    log(f'{log_prefix}/LR', learning_rate, global_step)
                    self.iter_scheduler.step()  # type: ignore

                # Zero the gradients when exiting a train step
                self.optimizer.zero_grad()
                # logging train metrics
                if self.extra_training_metrics_log_interval > self._last_train_log_step:
                    self._log_metrics(log_prefix, metrics_with_states,
                                      global_step)
                    self._last_train_log_step = i
            if self._last_train_log_step != i:
                # log again at end of step, if not logged at the end of
                # step before
                self._log_metrics(log_prefix, metrics_with_states, global_step)
Example #9
            #Print & log
            per_item_loss = t_loss.item() / batch_size
            train_ep_loss += t_loss.item()
            if total_steps % args.log_iter == 0:
                figure = draw_rec(dotprod.view(-1, 1), labels.view(-1, 1))
                writer.add_figure('heatmap',
                                  figure,
                                  global_step=total_steps,
                                  close=True)
                writer.add_scalar('batchLoss/train', t_loss.item(),
                                  total_steps)
                print('epoch {}, opt. step n°{}, loss per it. {:.2f}'.format(
                    epoch, total_steps, per_item_loss))

            del (t_loss)
            clip.clip_grad_norm_(model.parameters(), args.clip_norm)
            optimizer.step()

            # Annealing  LR
            if total_steps % args.anneal_iter == 0:
                scheduler.step()
                print("learning rate: %.6f" % scheduler.get_lr()[0])

            #Saving
            if total_steps % args.save_iter == 0:
                torch.save(model.state_dict(),
                           f"{args.save_path[:-4]}_iter_{total_steps}.pth")

        # Epoch logging
        writer.add_scalar('epochLoss/train', train_ep_loss, epoch)
        print(f'Epoch {epoch}, total loss : {train_ep_loss}')
Example #10
n_iter = len(trainset)
n_epoch = 10
total_iter = n_iter * n_epoch
train_acc = []
valid_acc = []
for epoch in range(n_epoch):
  # training
  acc = []
  model.train()
  for data in train_loader:
    loss = model(*[d.cuda() for d in data])
    optim.zero_grad()
    loss.backward()
    acc.append(model.acc)
    norm = clip_grad_norm_(model.parameters(), 10.0)
    optim.step()
  train_acc.append(numpy.mean(acc))

  # validation
  acc = []
  model.eval()
  for data in dev_loader:
    model(*[d.cuda() for d in data])
    acc.append(model.acc)
  valid_acc.append(numpy.mean(acc))
  print(f"epoch: {epoch}, train acc: {train_acc[-1]:.3f}, dev acc: {valid_acc[-1]:.3f}")

import matplotlib.pyplot as plt

plt.plot(range(len(train_acc)), train_acc, label="train acc")
Example #11
    def train_epoch(self,
                    model: nn.Module,
                    train_loader: DataLoader,
                    val_clean_loader: DataLoader,
                    val_triggered_loader: DataLoader,
                    epoch_num: int,
                    use_amp: bool = False):
        """
        Runs one epoch of training on the specified model

        :param model: the model to train for one epoch
        :param train_loader: a DataLoader object pointing to the training dataset
        :param val_clean_loader: a DataLoader object pointing to the validation dataset that is clean
        :param val_triggered_loader: a DataLoader object pointing to the validation dataset that is triggered
        :param epoch_num: the epoch number that is being trained
        :param use_amp: if True use automated mixed precision for FP16 training.
        :return: a list of statistics for batches where statistics were computed
        """

        pid = os.getpid()
        train_dataset_len = len(train_loader.dataset)
        loop = tqdm(
            train_loader,
            disable=self.optimizer_cfg.reporting_cfg.disable_progress_bar)

        scaler = None
        if use_amp:
            scaler = torch.cuda.amp.GradScaler()

        train_n_correct, train_n_total = None, None
        sum_batchmean_train_loss = 0
        running_train_acc = 0
        num_batches = len(train_loader)
        model.train()
        for batch_idx, (x, y_truth) in enumerate(loop):
            x = x.to(self.device)
            y_truth = y_truth.to(self.device)

            # put network into training mode & zero out previous gradient computations
            self.optimizer.zero_grad()

            # get predictions based on input & weights learned so far
            if use_amp:
                with torch.cuda.amp.autocast():
                    y_hat = model(x)
                    # compute metrics
                    batch_train_loss = self._eval_loss_function(y_hat, y_truth)
            else:
                y_hat = model(x)
                # compute metrics
                batch_train_loss = self._eval_loss_function(y_hat, y_truth)

            sum_batchmean_train_loss += batch_train_loss.item()

            running_train_acc, train_n_total, train_n_correct = _running_eval_acc(
                y_hat,
                y_truth,
                n_total=train_n_total,
                n_correct=train_n_correct,
                soft_to_hard_fn=self.soft_to_hard_fn,
                soft_to_hard_fn_kwargs=self.soft_to_hard_fn_kwargs)

            if np.isnan(sum_batchmean_train_loss) or np.isnan(
                    running_train_acc):
                _save_nandata(x, y_hat, y_truth, batch_train_loss,
                              sum_batchmean_train_loss, running_train_acc,
                              train_n_total, train_n_correct, model)

            # compute gradient
            if use_amp:
                # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
                # Backward passes under autocast are not recommended.
                # Backward ops run in the same dtype autocast chose for corresponding forward ops.
                scaler.scale(batch_train_loss).backward()
            else:
                if np.isnan(sum_batchmean_train_loss) or np.isnan(
                        running_train_acc):
                    _save_nandata(x, y_hat, y_truth, batch_train_loss,
                                  sum_batchmean_train_loss, running_train_acc,
                                  train_n_total, train_n_correct, model)

                batch_train_loss.backward()

            # perform gradient clipping if configured
            if self.optimizer_cfg.training_cfg.clip_grad:
                if use_amp:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(self.optimizer)

                if self.optimizer_cfg.training_cfg.clip_type == 'norm':
                    # clip_grad_norm_ modifies gradients in place
                    #  see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                    torch_clip_grad.clip_grad_norm_(
                        model.parameters(),
                        self.optimizer_cfg.training_cfg.clip_val,
                        **self.optimizer_cfg.training_cfg.clip_kwargs)
                elif self.optimizer_cfg.training_cfg.clip_type == 'val':
                    # clip_grad_val_ modifies gradients in place
                    #  see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                    torch_clip_grad.clip_grad_value_(
                        model.parameters(),
                        self.optimizer_cfg.training_cfg.clip_val)
                else:
                    msg = "Unknown clipping type for gradient clipping!"
                    logger.error(msg)
                    raise ValueError(msg)

            if use_amp:
                # scaler.step() first unscales the gradients of the optimizer's assigned params.
                # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(self.optimizer)
                # Updates the scale for next iteration.
                scaler.update()
            else:
                self.optimizer.step()

            loop.set_description('Epoch {}/{}'.format(epoch_num + 1,
                                                      self.num_epochs))
            loop.set_postfix(avg_train_loss=batch_train_loss.item())

            # report batch statistics to tensorflow
            if self.tb_writer:
                try:
                    batch_num = int(epoch_num * num_batches + batch_idx)
                    self.tb_writer.add_scalar(
                        self.optimizer_cfg.reporting_cfg.experiment_name +
                        '-train_loss',
                        batch_train_loss.item(),
                        global_step=batch_num)
                    self.tb_writer.add_scalar(
                        self.optimizer_cfg.reporting_cfg.experiment_name +
                        '-running_train_acc',
                        running_train_acc,
                        global_step=batch_num)
                except:
                    # TODO: catch specific expcetions
                    pass

            if batch_idx % self.num_batches_per_logmsg == 0:
                logger.info(
                    '{}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tTrainLoss: {:.6f}\tTrainAcc: {:.6f}'
                    .format(pid, epoch_num, batch_idx * len(x),
                            train_dataset_len, 100. * batch_idx / num_batches,
                            batch_train_loss.item(), running_train_acc))

        train_stats = EpochTrainStatistics(
            running_train_acc, sum_batchmean_train_loss / float(num_batches))

        # if we have validation data, we compute on the validation dataset
        num_val_batches_clean = len(val_clean_loader)
        if num_val_batches_clean > 0:
            logger.info('Running Validation on Clean Data')
            running_val_clean_acc, _, _, val_clean_loss = \
                _eval_acc(val_clean_loader, model, self.device,
                          self.soft_to_hard_fn, self.soft_to_hard_fn_kwargs, self._eval_loss_function)
        else:
            logger.info("No dataset computed for validation on clean dataset!")
            running_val_clean_acc = None
            val_clean_loss = None

        num_val_batches_triggered = len(val_triggered_loader)
        if num_val_batches_triggered > 0:
            logger.info('Running Validation on Triggered Data')
            running_val_triggered_acc, _, _, val_triggered_loss = \
                _eval_acc(val_triggered_loader, model, self.device,
                          self.soft_to_hard_fn, self.soft_to_hard_fn_kwargs, self._eval_loss_function)
        else:
            logger.info(
                "No dataset computed for validation on triggered dataset!")
            running_val_triggered_acc = None
            val_triggered_loss = None

        validation_stats = EpochValidationStatistics(
            running_val_clean_acc, val_clean_loss, running_val_triggered_acc,
            val_triggered_loss)
        if num_val_batches_clean > 0:
            logger.info(
                '{}\tTrain Epoch: {} \tCleanValLoss: {:.6f}\tCleanValAcc: {:.6f}'
                .format(pid, epoch_num, val_clean_loss, running_val_clean_acc))
        if num_val_batches_triggered > 0:
            logger.info(
                '{}\tTrain Epoch: {} \tTriggeredValLoss: {:.6f}\tTriggeredValAcc: {:.6f}'
                .format(pid, epoch_num, val_triggered_loss,
                        running_val_triggered_acc))

        if self.tb_writer:
            try:
                batch_num = int((epoch_num + 1) * num_batches)
                if num_val_batches_clean > 0:
                    self.tb_writer.add_scalar(
                        self.optimizer_cfg.reporting_cfg.experiment_name +
                        '-clean-val-loss',
                        val_clean_loss,
                        global_step=batch_num)
                    self.tb_writer.add_scalar(
                        self.optimizer_cfg.reporting_cfg.experiment_name +
                        '-clean-val_acc',
                        running_val_clean_acc,
                        global_step=batch_num)
                if num_val_batches_triggered > 0:
                    self.tb_writer.add_scalar(
                        self.optimizer_cfg.reporting_cfg.experiment_name +
                        '-triggered-val-loss',
                        val_triggered_loss,
                        global_step=batch_num)
                    self.tb_writer.add_scalar(
                        self.optimizer_cfg.reporting_cfg.experiment_name +
                        '-triggered-val_acc',
                        running_val_triggered_acc,
                        global_step=batch_num)
            except:
                pass

        # update the lr-scheduler if necessary
        if self.lr_scheduler is not None:
            if self.optimizer_cfg.training_cfg.lr_scheduler_call_arg is None:
                self.lr_scheduler.step()
            elif self.optimizer_cfg.training_cfg.lr_scheduler_call_arg.lower(
            ) == 'val_acc':
                val_acc = validation_stats.get_val_acc()
                if val_acc is not None:
                    self.lr_scheduler.step(val_acc)
                else:
                    msg = "val_clean_acc not defined b/c validation dataset is not defined! Ignoring LR step!"
                    logger.warning(msg)
            elif self.optimizer_cfg.training_cfg.lr_scheduler_call_arg.lower(
            ) == 'val_loss':
                val_loss = validation_stats.get_val_loss()
                if val_loss is not None:
                    self.lr_scheduler.step(val_loss)
                else:
                    msg = "val_clean_loss not defined b/c validation dataset is not defined! Ignoring LR step!"
                    logger.warning(msg)
            else:
                msg = "Unknown mode for calling lr_scheduler!"
                logger.error(msg)
                raise ValueError(msg)

        return train_stats, validation_stats
Example #12
 def step(self, *args, **kwargs):
     assert (self.clip_value is not None)
     clip_grad_norm_(self._parameters, self.clip_value)
     super().step(*args, **kwargs)
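A sketch of how such a subclass could be assembled; the _parameters and clip_value attributes come from the snippet above, while the Adam base class and constructor are assumptions:

import torch
from torch.nn.utils import clip_grad_norm_

class ClippedAdam(torch.optim.Adam):
    """Adam that clips the global gradient norm to clip_value before every step (sketch)."""

    def __init__(self, params, clip_value, **kwargs):
        params = list(params)
        super().__init__(params, **kwargs)
        self._parameters = params
        self.clip_value = clip_value

    def step(self, *args, **kwargs):
        assert (self.clip_value is not None)
        clip_grad_norm_(self._parameters, self.clip_value)
        super().step(*args, **kwargs)

# usage sketch
model = torch.nn.Linear(4, 2)
opt = ClippedAdam(model.parameters(), clip_value=5.0, lr=1e-3)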
Example #13
File: agent.py  Project: pybnen/conbas
    def update(self, discount: float, replay_memory: PrioritizedReplayMemory,
               loss_fn: Callable[[torch.Tensor, torch.Tensor],
                                 torch.Tensor], optimizer: optim.Optimizer,
               clip_grad_norm: float):  # -> Tuple[Number, Number]:
        assert not self.lstm_dqn_target.training
        assert self.lstm_dqn.training

        transitions, weights, indices = replay_memory.sample()

        # This is a neat trick to convert a batch transitions into one
        # transition that contains in each attribute a batch of that attribute,
        # found here: https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
        # and explained in more detail here: https://stackoverflow.com/questions/19339/transpose-unzip-function-inverse-of-zip/19343#19343
        batch = TransitionBatch(*zip(*transitions))  # type: ignore

        # create tensors for update
        input_tensor, input_lengths = self.pad_input_ids(batch.observation)
        command_indices = torch.stack(batch.command_index,
                                      dim=0).to(self.device)
        rewards = torch.tensor(batch.reward,
                               dtype=torch.float32,
                               device=self.device)
        next_input_tensor, next_input_lengths = self.pad_input_ids(
            batch.next_observation)
        non_terminal_mask = 1.0 - torch.tensor(
            batch.done, dtype=torch.float32, device=self.device)
        weights = torch.from_numpy(weights).to(device=self.device)

        # q_values from policy network, Q(obs, a, phi)
        q_values = self.q_values(
            input_tensor, input_lengths, self.lstm_dqn).gather(
                dim=1, index=command_indices.unsqueeze(-1)).squeeze(-1)
        assert q_values.requires_grad

        # no need to build a computation graph here
        with torch.no_grad():
            # argmax_a Q(next_obs, a, phi)
            _, argmax_a = self.q_values(next_input_tensor, next_input_lengths,
                                        self.lstm_dqn).max(dim=1)
            # Q(next_obs, argmax_a Q(next_obs, a, phi), phi_minus)
            next_q_values = self.q_values(
                next_input_tensor, next_input_lengths,
                self.lstm_dqn_target).gather(
                    dim=1, index=argmax_a.unsqueeze(-1)).squeeze(-1)
            assert not next_q_values.requires_grad
            # target = reward + discount * Q(next_obs, argmax_a Q(next_obs, a, phi), phi_minus) * non_terminal_mask
            target = rewards + non_terminal_mask * discount * next_q_values.detach(
            )

        loss = loss_fn(q_values, target) * weights
        priorities = loss.detach().cpu().numpy()
        loss = loss.mean()

        # update step
        optimizer.zero_grad()
        loss.backward(retain_graph=False)
        total_norm = clip_grad_norm_(self.lstm_dqn.parameters(),
                                     clip_grad_norm)
        optimizer.step()

        # update priorities
        replay_memory.update_priorities(indices, priorities)

        # in server torch version total_norm is float
        if type(total_norm) == torch.Tensor:
            total_norm = total_norm.detach().cpu().item()

        return loss.detach().cpu().item(), total_norm
Example #14
def train_loop(
    args,
    V,
    iter,
    model,
    parameters,
    optimizer,
    scheduler,
    valid_iter=None,
    verbose=False,
):
    global WANDB_STEP

    noise_scales = np.linspace(1, 0, args.noise_anneal_steps)
    total_ll = 0
    total_elbo = 0
    n = 0
    # check is performed at end of epoch outside loop as well
    checkpoint = len(iter) // (args.num_checks - 1)
    with th.enable_grad():
        lpz = None
        last_states = None
        for i, batch in enumerate(iter):
            model.train(True)
            WANDB_STEP += 1
            optimizer.zero_grad()

            text = batch.textp1 if "lstm" in args.model else batch.text
            if args.iterator == "bucket":
                lpz = None
                last_states = None

            mask, lengths, n_tokens = get_mask_lengths(text, V)
            if model.timing:
                start_forward = timep.time()

            # check if iterator == bptt
            losses, lpz, last_states = model.score(text,
                                                   lpz=lpz,
                                                   last_states=last_states,
                                                   mask=mask,
                                                   lengths=lengths)

            if model.timing:
                print(f"forward time: {timep.time() - start_forward}")
            total_ll += losses.evidence.detach()
            if losses.elbo is not None:
                total_elbo += losses.elbo.detach()
            n += n_tokens

            loss = -losses.loss / n_tokens
            if model.timing:
                start_backward = timep.time()
            loss.backward()
            if model.timing:
                print(f"backward time: {timep.time() - start_backward}")
            clip_grad_norm_(parameters, args.clip)
            if args.schedule not in valid_schedules:
                # sched before opt since we want step = 1?
                # this is how huggingface does it
                scheduler.step()
            optimizer.step()
            #import pdb; pdb.set_trace()
            #wandb.log({
            #"running_training_loss": total_ll / n,
            #"running_training_ppl": math.exp(min(-total_ll / n, 700)),
            #}, step=WANDB_STEP)

            if verbose and i % args.report_every == args.report_every - 1:
                report(
                    Pack(evidence=total_ll, elbo=total_elbo),
                    n,
                    f"Train batch {i}",
                )

            if valid_iter is not None and i % checkpoint == checkpoint - 1:
                v_start_time = time.time()
                #eval_fn = cached_eval_loop if args.model == "mshmm" else eval_loop
                #valid_losses, valid_n  = eval_loop(
                #valid_losses, valid_n  = cached_eval_loop(
                if args.model == "mshmm" or args.model == "factoredhmm":
                    if args.num_classes > 2**15:
                        eval_fn = mixed_cached_eval_loop
                    else:
                        eval_fn = cached_eval_loop
                elif args.model == "hmm":
                    eval_fn = cached_eval_loop
                else:
                    eval_fn = eval_loop
                valid_losses, valid_n = eval_fn(
                    args,
                    V,
                    valid_iter,
                    model,
                )
                report(valid_losses, valid_n, "Valid eval", v_start_time)
                #wandb.log({
                #"valid_loss": valid_losses.evidence / valid_n,
                #"valid_ppl": math.exp(-valid_losses.evidence / valid_n),
                #}, step=WANDB_STEP)

                update_best_valid(valid_losses, valid_n, model, optimizer,
                                  scheduler, args.name)

                #wandb.log({
                #"lr": optimizer.param_groups[0]["lr"],
                #}, step=WANDB_STEP)
                scheduler.step(valid_losses.evidence)

                # remove this later?
                if args.log_counts > 0 and args.keep_counts > 0:
                    # TODO: FACTOR OUT
                    counts = (model.counts /
                              model.counts.sum(0, keepdim=True))[:, 4:]
                    c, v = counts.shape
                    #cg4 = counts > 1e-4
                    #cg3 = counts > 1e-3
                    cg2 = counts > 1e-2

                    #wandb.log({
                    #"avgcounts@1e-4": cg4.sum().item() / float(v),
                    #"avgcounts@1e-3": cg3.sum().item() / float(v),
                    #"avgcounts@1e-2": cg2.sum().item() / float(v),
                    #"maxcounts@1e-4": cg4.sum(0).max().item() / float(v),
                    #"maxcounts@1e-3": cg3.sum(0).max().item() / float(v),
                    #"maxcounts@1e-2": cg2.sum(0).max().item(),
                    #"mincounts@1e-4": cg4.sum(0).min().item() / float(v),
                    #"mincounts@1e-3": cg3.sum(0).min().item() / float(v),
                    #"mincounts@1e-2": cg2.sum(0).min().item(),
                    #"maxcounts": counts.sum(0).max().item(),
                    #"mincounts": counts.sum(0).min().item(),
                    #}, step=WANDB_STEP)
                    del cg2
                    del counts

    return Pack(evidence=total_ll, elbo=total_elbo), n
Example #15
    def step(self,
             feed_dict,
             grad_clip=0.,
             reduce_func=default_reduce_func,
             cast_tensor=False,
             measure_time=False):
        if hasattr(self.model, 'train_step'):
            return self.model.train_step(self.optimizer, feed_dict)

        assert self._model.training, 'Step a evaluation-mode model.'
        extra = dict()

        self.trigger_event('step:before', self)

        if cast_tensor:
            feed_dict = as_tensor(feed_dict)

        if measure_time:
            end_time = cuda_time()

        self.trigger_event('forward:before', self, feed_dict)
        loss, monitors, output_dict = self._model(feed_dict)
        self.trigger_event('forward:after', self, feed_dict, loss, monitors,
                           output_dict)

        if measure_time:
            extra['time/forward'] = cuda_time() - end_time
            end_time = cuda_time(False)

        loss = reduce_func('loss', loss)
        monitors = {k: reduce_func(k, v) for k, v in monitors.items()}

        loss_f = as_float(loss)
        monitors_f = as_float(monitors)

        if measure_time:
            extra['time/loss'] = cuda_time() - end_time
            end_time = cuda_time(False)

        self._optimizer.zero_grad()
        self.trigger_event('backward:before', self, feed_dict, loss, monitors,
                           output_dict)
        if loss.requires_grad:
            loss.backward()
            if grad_clip > 0:
                from torch.nn.utils.clip_grad import clip_grad_norm_
                clip_grad_norm_(self.model.parameters(), grad_clip)

        if measure_time:
            extra['time/backward'] = cuda_time() - end_time
            end_time = cuda_time(False)

        self.trigger_event('backward:after', self, feed_dict, loss, monitors,
                           output_dict)
        if loss.requires_grad:
            self._optimizer.step()

        if measure_time:
            extra['time/optimize'] = cuda_time() - end_time
            end_time = cuda_time(False)

        self.trigger_event('step:after', self)

        return loss_f, monitors_f, output_dict, extra
Example #16
def train_epoch(dataloader, model, optimizer, epoch=None):
    model.train()
    stats = collections.defaultdict(list)
    for batch_idx, data in enumerate(dataloader):
        fbank, seq_lens, tokens, language = data
        fbank, seq_lens, tokens = fbank.cuda(), seq_lens.cuda(), tokens.cuda()
        if isinstance(optimizer, dict):
            optimizer[language].zero_grad()
        else:
            optimizer.zero_grad()
        model.zero_grad()
        if args.ngpu <= 1 or args.dist_train:
            ctc_att_loss, sim_adapter_guide_loss = model(
                fbank, seq_lens, tokens,
                language)  # .mean() # / self.accum_grad
        else:
            # apex does not support torch.nn.DataParallel
            ctc_att_loss, sim_adapter_guide_loss = (
                data_parallel(model, (fbank, seq_lens, tokens, language),
                              range(args.ngpu))  # .mean() # / self.accum_grad
            )
        loss = ctc_att_loss.mean()
        if args.sim_adapter:
            if hasattr(model, "module"):
                sim_adapter_reg_loss = model.module.get_fusion_regularization_loss(
                )
            else:
                sim_adapter_reg_loss = model.get_fusion_regularization_loss()
            loss = loss + sim_adapter_reg_loss
            stats["sim_adapter_reg_loss_lst"].append(
                sim_adapter_reg_loss.item())
            if args.guide_loss_weight > 0:
                if args.guide_loss_weight_decay_steps > 0:
                    n_batch = len(dataloader)
                    current_iter = float(batch_idx + (epoch - 1) * n_batch)
                    frac_done = 1.0 * float(
                        current_iter) / args.guide_loss_weight_decay_steps
                    current_weight = args.guide_loss_weight * max(
                        0., 1. - frac_done)
                    stats["sim_adapter_guide_loss_weight"] = current_weight
                else:
                    current_weight = args.guide_loss_weight
                sim_adapter_guide_loss = sim_adapter_guide_loss.mean()
                loss = loss + current_weight * sim_adapter_guide_loss
                stats["sim_adapter_guide_loss_lst"].append(
                    sim_adapter_guide_loss.item())

        if not hasattr(model, "module"):
            if hasattr(model, "acc") and model.acc is not None:
                stats["acc_lst"].append(model.acc)
                model.acc = None
        else:
            if hasattr(model, "acc") and model.module.acc is not None:
                stats["acc_lst"].append(model.module.acc)
                model.module.acc = None
        loss.backward()
        grad_norm = clip_grad_norm_(model.parameters(), args.grad_clip)
        if math.isnan(grad_norm):
            logging.warning("grad norm is nan. Do not update model.")
        else:
            if isinstance(optimizer, dict):
                optimizer[language].step()
            else:
                optimizer.step()
            stats["loss_lst"].append(loss.item())
        logging.warning(f"Training batch: {batch_idx+1}/{len(dataloader)}")
    return dict_average(stats)
Example #17
            if lambda1 > 0:
                reg1_loss = calculate_l1_loss(lstm_net)

            loss = mse_criterion(
                loss_outputs, loss_targets
            ) + lambda1 * reg1_loss + ec_lambda * unsup_loss + dc_lambda * dc_unsup_loss

            avg_loss += loss
            avg_unsup_loss += unsup_loss
            avg_dc_unsup_loss += dc_unsup_loss

            batches_done += 1
            #backward prop
            loss.backward(retain_graph=False)
            if grad_clip > 0:
                clip_grad_norm_(lstm_net.parameters(), grad_clip, norm_type=2)

            #optimize
            optimizer.step()

            #zero the parameter gradients
            optimizer.zero_grad()

            #print statistics
            running_loss += loss.item()
            if verbose:
                if i % 3 == 2:
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, i + 1, running_loss / 3))
                    running_loss = 0.0
        avg_loss = avg_loss / batches_done
Example #18
def train_maml_epoch(dataloader, model, optimizer, epoch=None):
    model.train()
    stats = collections.defaultdict(list)

    for batch_idx, total_batches in enumerate(dataloader):
        i = batch_idx  # current iteration in epoch
        len_dataloader = len(dataloader)  # total iteration in epoch
        meta_iters = args.epochs * len_dataloader
        current_iter = float(i + (epoch - 1) * len_dataloader)
        frac_done = 1.0 * float(current_iter) / meta_iters
        current_outerstepsize = args.meta_lr * (1. - frac_done)

        weights_original = copy.deepcopy(model.state_dict())
        new_weights = []
        for total_batch in total_batches:  # Iter by languages
            in_batch_size = int(total_batch[0].shape[0] /
                                2)  # In-language batch size
            for meta_step in range(2):  # Meta-train & meta-valid
                if meta_step == 1:
                    last_backup = copy.deepcopy(model.state_dict())
                else:
                    last_backup = None
                batch = list(copy.deepcopy(total_batch))
                for i_batch in range(len(batch) - 1):
                    batch[i_batch] = batch[i_batch][meta_step *
                                                    in_batch_size:(1 +
                                                                   meta_step) *
                                                    in_batch_size]
                batch = tuple(batch)

                fbank, seq_lens, tokens, language = batch
                fbank, seq_lens, tokens = fbank.cuda(), seq_lens.cuda(
                ), tokens.cuda()
                optimizer.zero_grad()
                model.zero_grad()
                if args.ngpu <= 1 or args.dist_train:
                    loss = model(fbank, seq_lens, tokens,
                                 language).mean()  # / self.accum_grad
                else:
                    # apex does not support torch.nn.DataParallel
                    loss = (
                        data_parallel(
                            model, (fbank, seq_lens, tokens, language),
                            range(args.ngpu)).mean()  # / self.accum_grad
                    )
                # print(loss.item())
                loss.backward()
                grad_norm = clip_grad_norm_(model.parameters(), args.grad_clip)
                if math.isnan(grad_norm):
                    logging.warning("grad norm is nan. Do not update model.")
                else:
                    optimizer.step()

                if meta_step == 1:  # Record meta valid
                    if not hasattr(model, "module"):
                        if hasattr(model, "acc") and model.acc is not None:
                            stats["acc_lst"].append(model.acc)
                            model.acc = None
                    else:
                        if hasattr(model,
                                   "acc") and model.module.acc is not None:
                            stats["acc_lst"].append(model.module.acc)
                            model.module.acc = None
                    stats["loss_lst"].append(loss.item())
                    stats["meta_lr"] = current_outerstepsize
                    optimizer.zero_grad()

            for name in last_backup:
                # Compute meta-gradient
                last_backup[name] = model.state_dict(
                )[name] - last_backup[name]
            # Change back to the original parameters for the new language
            new_weights.append(
                last_backup
            )  # updates.append(subtract_vars(self._model_state.export_variables(), last_backup))
            model.load_state_dict(
                {name: weights_original[name]
                 for name in weights_original})

        ws = len(new_weights)
        # Compute average meta-gradient
        fweights = {
            name: new_weights[0][name] / float(ws)
            for name in new_weights[0]
        }
        for i in range(1, ws):
            for name in new_weights[i]:
                fweights[
                    name] = fweights[name] + new_weights[i][name] / float(ws)
        model.load_state_dict({
            name:
            weights_original[name] + (fweights[name] * current_outerstepsize)
            for name in weights_original
        })

        logging.warning(f"Training batch: {batch_idx+1}/{len(dataloader)}")
    return dict_average(stats)
Example #19
def learn_return(network, optimizer, dataset, log_dir, args):
    network.train()
    #check if gpu available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Assume that we are on a CUDA machine, then this should print a CUDA device:
    print(device)
    loss_criterion = nn.CrossEntropyLoss()

    logs = [[],[],[],[]] # (losses, epoch_losses, accuracies, magnitudes)

    retsymb = '\n' if args.grid else '\r'
    debug = True
    print_interval = 100

    for epoch in range(args.num_iter):
        dloader = DataLoader(dataset, shuffle=True, pin_memory=True, num_workers=8)

        n_correct  = 0
        frames     = 0
        cum_loss   = 0.0
        cum_mag    = 0.0
        epoch_loss = 0
        start_time = time.time()
        for i, data in enumerate(dloader):
            print_epoch = i % print_interval == 0

            traj_i, traj_j, labels = data
            actions_i = [traj_i[1][0].to(device)]
            actions_j = [traj_j[1][0].to(device)]
            traj_i    = [traj_i[0][0].to(device)]
            traj_j    = [traj_j[0][0].to(device)]
            labels    = labels[0].to(device)

            frames += traj_i[0].shape[0] + traj_j[0].shape[0]

            if args.bc:
                for traj, actions in zip(traj_i+traj_j, actions_i+actions_j):
                    outputs = network.bc(traj)
                    loss = loss_criterion(outputs, actions.long().view(-1)).mean()
                    loss.backward()
                    abs = torch.Tensor([0])

            else:
                #forward + backward + optimize
                outputs, abs = network.forward(traj_i, traj_j, actions_i, actions_j, print_epoch)
                #outputs = outputs.unsqueeze(0)
                loss = loss_criterion(outputs, labels.long()).mean()
                # a loss below ln(2) (~0.693) means the correct trajectory is ranked higher
                if loss < 0.693:
                    n_correct += 1
                loss = loss + args.l1_reg * abs
                loss.backward()


            if i % args.batch_size == 0:

                '''
                print('')
                print('###########################')
                _norm = []
                for p in network.parameters():
                    if p.grad is not None:
                        _norm += [p.grad.view(-1).detach()]
                _norm = torch.cat(_norm).norm()
                print('grad norm pre-clip: {}'.format(_norm))
                '''
                

                clip_grad_norm_(network.parameters(), 10)


                '''
                _norm = []
                for p in network.parameters():
                    if p.grad is not None:
                        _norm += [p.grad.view(-1).detach()]
                _norm = torch.cat(_norm).norm()
                print('grad norm post-clip: {}'.format(_norm))
                print(outputs.detach().cpu().numpy(), labels.detach().cpu().numpy())
                '''

                optimizer.step()
                optimizer.zero_grad()

            #print stats to see if learning
            item_loss   = loss.item()
            epoch_loss += item_loss
            cum_loss   += item_loss
            cum_mag    += abs.item()
            # The running loss shown here is approximate (it resets every print interval)
            if print_epoch:
                #print(i)
                eps = print_interval / (time.time() - start_time)
                fps = frames / (time.time() - start_time)
                frames = 0
                if i > 0:
                    cum_loss = cum_loss / print_interval
                    cum_mag  = cum_mag / print_interval
                print("epoch {}:{}/{} loss {} mag {}   |   eps {} fps {}".format(epoch+1, i, len(dataset), cum_loss, cum_mag, eps, fps), end=retsymb)
                logs[0] += [cum_loss]
                logs[3] += [cum_mag]
                #print(abs_rewards)
                cum_loss = 0.0
                cum_mag  = 0.0
                start_time = time.time()
        #if debug:
        #    print('\n\n                                       ####\n')
        accuracy = n_correct / len(dloader)
        print('epoch {} average loss: {} average accuracy: {}                                  '.format(epoch+1, epoch_loss / len(dataset), accuracy))
        logs[1] += [epoch_loss / len(dataset)]
        logs[2] += [accuracy]
        #'''
        for g in optimizer.param_groups:
            g['lr'] *= 0.95
        #'''

        #print("check pointing")
        torch.save(network.state_dict(), args.model_path)

    with open(log_dir, 'wb') as f:
        pickle.dump(logs, f)
    print("finished training")
Example #20
    def train_on_batch(self, batch):
        self.optim.zero_grad()

        x = batch[0].to(self.device_name)
        y = batch[1].to(self.device_name)

        good_losses = []
        good_accs = []
        bad_losses = []
        bad_accs = []

        good_idx = self.model.get_class_idx(1).tolist()
        bad_idx = self.model.get_class_idx(-1).tolist()

        num_good_points_to_use = min(len(good_idx), self.config.hp.num_good_cells_per_update)
        num_bad_points_to_use = min(len(bad_idx), self.config.hp.num_bad_cells_per_update)

        for i, j in random.sample(good_idx, num_good_points_to_use):
            preds = self.model.run_from_weights(self.model.compute_point(i,j), x)

            good_loss = self.criterion(preds, y).mean()
            good_losses.append(good_loss.item())
            good_loss /= num_good_points_to_use
            good_loss.backward()  # backward per point also frees that point's graph

            good_accs.append((preds.argmax(dim=1) == y).float().mean().item())

        for i, j in random.sample(bad_idx, num_bad_points_to_use):
            preds = self.model.run_from_weights(self.model.compute_point(i,j), x)

            bad_loss = self.criterion(preds, y).mean()
            bad_losses.append(bad_loss.item())
            bad_loss = bad_loss.clamp(0, self.config.hp.neg_loss_clip_threshold)
            bad_loss /= num_bad_points_to_use
            bad_loss *= self.config.hp.get('negative_loss_coef', 1.)
            bad_loss *= -1  # negate so that gradient descent increases the loss at bad points
            bad_loss.backward()  # backward per point also frees that point's graph

            bad_accs.append((preds.argmax(dim=1) == y).float().mean().item())

        good_losses = np.array(good_losses)
        good_accs = np.array(good_accs)
        bad_losses = np.array(bad_losses)
        bad_accs = np.array(bad_accs)

        # Adding regularization
        if self.config.hp.parametrization_type != "up_orthogonal":
            ort = self.model.compute_ort_reg()
            norm_diff = self.model.compute_norm_reg()
            reg_loss = self.config.hp.ort_l2_coef * ort.pow(2) + self.config.hp.norm_l2_coef * norm_diff.pow(2)
            reg_loss.backward()

            self.writer.add_scalar('Reg/ort', ort.item(), self.num_iters_done)
            self.writer.add_scalar('Reg/norm_diff', norm_diff.item(), self.num_iters_done)

        clip_grad_norm_(self.model.parameters(), self.config.hp.grad_clip_threshold)
        self.optim.step()

        if self.scheduler is not None:
            self.scheduler.step()

        self.writer.add_scalar('good/train/loss', good_losses.mean().item(), self.num_iters_done)
        self.writer.add_scalar('good/train/acc', good_accs.mean().item(), self.num_iters_done)
        self.writer.add_scalar('bad/train/loss', bad_losses.mean().item(), self.num_iters_done)
        self.writer.add_scalar('bad/train/acc', bad_accs.mean().item(), self.num_iters_done)
        self.writer.add_scalar('diff/train/loss', good_losses.mean().item() - bad_losses.mean().item(), self.num_iters_done)
        self.writer.add_scalar('diff/train/acc', good_accs.mean().item() - bad_accs.mean().item(), self.num_iters_done)

        self.writer.add_scalar('Stats/lengths/right', self.model.right.norm(), self.num_iters_done)
        self.writer.add_scalar('Stats/lengths/up', self.model.up.norm(), self.num_iters_done)
        self.writer.add_scalar('Stats/grad_norms/origin', self.model.origin.grad.norm().item(), self.num_iters_done)
        self.writer.add_scalar('Stats/grad_norms/right_param', self.model.right_param.grad.norm().item(), self.num_iters_done)
        self.writer.add_scalar('Stats/grad_norms/up_param', self.model.up_param.grad.norm().item(), self.num_iters_done)
        self.writer.add_scalar('Stats/grad_norms/scaling', self.model.scaling_param.grad.norm().item(), self.num_iters_done)
        self.writer.add_scalar('Stats/scaling', self.model.scaling_param.item(), self.num_iters_done)
Example #21
def train(model, loader, optimizer, criterion, scheduler, step, epoch, device,
          args):
    before_load = time.time()
    # Start training
    model.train()
    while True:
        for batch in loader:
            # torch.LongTensor, (batch_size, seq_length)
            txt = batch['text']
            # torch.Tensor, (batch_size, max_time, hps.n_mels)
            mel = batch['mel']
            # torch.Tensor, (batch_size, max_time, hps.n_fft)
            mag = batch['mag']
            # torch.LongTensor, (batch_size, )
            txt_len = batch['text_length']
            frame_len = batch['frame_length']

            if hps.bucket:
                # If bucketing, the shape will be (1, batch_size, ...)
                txt = txt.squeeze(0)
                mel = mel.squeeze(0)
                mag = mag.squeeze(0)
                txt_len = txt_len.squeeze(0)
                frame_len = frame_len.squeeze(0)
            # GO frame
            GO_frame = torch.zeros(mel[:, :1, :].size())
            if args.cuda:
                txt = txt.to(device)
                mel = mel.to(device)
                mag = mag.to(device)
                GO_frame = GO_frame.to(device)

            # Build decoder input: GO frame followed by every r-th target mel frame
            decoder_input = torch.cat(
                [GO_frame, mel[:, hps.reduction_factor::hps.reduction_factor, :]],
                dim=1)

            load_time = time.time() - before_load
            before_step = time.time()

            _batch = model(text=txt,
                           frames=decoder_input,
                           text_length=txt_len,
                           frame_length=frame_len)
            _mel = _batch['mel']
            _mag = _batch['mag']
            _attn = _batch['attn']

            # Optimization
            optimizer.zero_grad()
            loss_mel = criterion(_mel, mel)
            loss_mag = criterion(_mag, mag)
            loss = loss_mel + loss_mag
            loss.backward()
            # Gradient clipping
            total_norm = clip_grad_norm_(model.parameters(),
                                         max_norm=hps.clip_norm)
            # Apply gradient
            optimizer.step()
            # Adjust learning rate
            scheduler.step()
            process_time = time.time() - before_step
            if step % hps.log_every_step == 0:
                lr_curr = optimizer.param_groups[0]['lr']
                log = '[{}-{}] total_loss: {:.3f}, mel_loss: {:.3f}, mag_loss: {:.3f}, grad: {:.3f}, lr: {:.3e}, time: {:.2f} + {:.2f} sec'.format(
                    epoch, step, loss.item(), loss_mel.item(), loss_mag.item(),
                    total_norm, lr_curr, load_time, process_time)
                print(log)
            if step % hps.save_model_every_step == 0:
                save(filepath='tmp/ckpt/ckpt_{}.pth.tar'.format(step),
                     model=model.state_dict(),
                     optimizer=optimizer.state_dict(),
                     step=step,
                     epoch=epoch)

            if step % hps.save_result_every_step == 0:
                sample_idx = random.randint(0, hps.batch_size - 1)
                attn_sample = _attn[sample_idx].detach().cpu().numpy()
                mag_sample = _mag[sample_idx].detach().cpu().numpy()
                wav_sample = mag2wav(mag_sample)
                # Save results
                save_alignment(attn_sample, step,
                               'tmp/plots/attn_{}.png'.format(step))
                save_spectrogram(mag_sample,
                                 'tmp/plots/spectrogram_{}.png'.format(step))
                save_wav(wav_sample, 'tmp/results/wav_{}.wav'.format(step))
            before_load = time.time()
            step += 1
        epoch += 1
Example #22
 def clip_grads(self, params):
     # operations on the gradient of parameters
     # params = self.invalid_to_zero(params)
     clip_grad.clip_grad_norm_(filter(lambda p: p.requires_grad, params),
                               **self.grad_clip)
Example #23
 def clip_grads(self, params):
     # operations on the gradient of parameters
     clip_grad.clip_grad_norm_(filter(lambda p: p.requires_grad, params), **self.grad_clip)
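A minimal sketch of how the **self.grad_clip pattern in the two clip_grads helpers above is typically wired up; the concrete values below are assumptions, not taken from the original code.

from torch.nn.utils import clip_grad

# Assumed example values; grad_clip is simply a dict of keyword arguments
# forwarded to clip_grad_norm_.
grad_clip = dict(max_norm=35, norm_type=2)

def clip_grads(params, grad_clip):
    trainable = [p for p in params if p.requires_grad and p.grad is not None]
    if trainable:
        return clip_grad.clip_grad_norm_(trainable, **grad_clip)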
Example #24
         frame_data['gt_dets'], frame_data['gt_ids'],
         frame_data['gt_vis'])
     loss_j = model.loss(output, prev_tracked_ids,
                         output_tracked_ids, frame_data)
     losses.append(loss_j)
     # print(model._cs_pos[0].weight.mean(), 'after')
 loss, logs, logh = parse_losses(losses)
 if isinstance(loss, torch.Tensor):
     # for name, one in model.named_parameters():
     #     if one.grad is not None:
     #         print(name, one.grad.mean())
     loss.backward()
     # for name, one in model.named_parameters():
     #     if one.grad is not None:
     #         print(name, one.grad.mean())
     clip_grad_norm_(model.parameters(), 1)
     opt.step()
 # print(model._cs_pos[0].weight.mean(), 'stepped')
 lr = lr_scheduler.get_last_lr()[0]
 if i < epoch_outer or i % 10 == 1:
     print(
         'epoch %d (outer %d inner %d [%d:%d]) iter %d/%d: lr %.6f'
         % (lr_scheduler.last_epoch, eo, e, second, third, i,
            len(dl), float(lr)))
     print('\t', end='')
     for k in logs:
         print('.%s: %.6f(%.6f), ' % (k, logs[k], logh[k]),
               end='')
     print(' loss %.6f' % float(loss))
     for j in range(len(losses)):
         print('\tstep %d;' % j, end='')
Example #25
    def train(self, batch, p_idx, v_idx):
        """Model training.

        Args:
            batch (dict): The dictionary of a batch of experience. For example:
                {
                    "s": the dictionary of state,
                    "a": model actions in numpy array,
                    "R": the n-step accumulated reward,
                    "s"": the dictionary of the next state,
                }
            p_idx (int): The identity of the port doing the action.
            v_idx (int): The identity of the vessel doing the action.

        Returns:
            a_loss (float): action loss.
            c_loss (float): critic loss.
            e_loss (float): entropy loss.
            tot_norm (float): the L2 norm of the gradient.

        """
        self._tot_batchs += 1
        item_a_loss, item_c_loss, item_e_loss = 0, 0, 0
        obs_batch = batch["s"]
        action_batch = batch["a"]
        return_batch = batch["R"]
        next_obs_batch = batch["s_"]

        obs_batch = gnn_union(
            obs_batch["p"], obs_batch["po"], obs_batch["pedge"], obs_batch["v"], obs_batch["vo"], obs_batch["vedge"],
            self._p2p_adj, obs_batch["ppedge"], obs_batch["mask"], self._device)
        action_batch = torch.from_numpy(action_batch).long().to(self._device)
        return_batch = torch.from_numpy(return_batch).float().to(self._device)
        next_obs_batch = gnn_union(
            next_obs_batch["p"], next_obs_batch["po"], next_obs_batch["pedge"], next_obs_batch["v"],
            next_obs_batch["vo"], next_obs_batch["vedge"], self._p2p_adj, next_obs_batch["ppedge"],
            next_obs_batch["mask"], self._device)

        # Train actor network.
        self._optimizer["a&c"].zero_grad()

        # Every port has a value.
        # values.shape: (batch, p_cnt)
        probs, values = self._model_dict["a&c"](obs_batch, a=True, p_idx=p_idx, v_idx=v_idx, c=True)
        distribution = Categorical(probs)
        log_prob = distribution.log_prob(action_batch)
        entropy_loss = distribution.entropy()

        _, values_ = self._model_dict["a&c"](next_obs_batch, c=True)
        advantage = return_batch + self._value_discount * values_.detach() - values

        if self._entropy_factor != 0:
            # actor_loss = actor_loss* torch.log(entropy_loss + np.e)
            advantage[:, p_idx] += self._entropy_factor * entropy_loss.detach()

        actor_loss = - (log_prob * torch.sum(advantage, axis=-1).detach()).mean()

        item_a_loss = actor_loss.item()
        item_e_loss = entropy_loss.mean().item()

        # Train critic network.
        critic_loss = torch.sum(advantage.pow(2), axis=1).mean()
        item_c_loss = critic_loss.item()
        # torch.nn.utils.clip_grad_norm_(self._critic_model.parameters(),0.5)
        tot_loss = 0.1 * actor_loss + critic_loss
        tot_loss.backward()
        tot_norm = clip_grad.clip_grad_norm_(self._model_dict["a&c"].parameters(), 1)
        self._optimizer["a&c"].step()
        return item_a_loss, item_c_loss, item_e_loss, float(tot_norm)
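A minimal sketch (hypothetical helper) isolating what Example #25 does with the clipping call's return value: clip_grad_norm_ reports the total norm measured before clipping, which the train method passes back as tot_norm for monitoring.

import torch
from torch.nn.utils import clip_grad_norm_

def clip_and_report(model: torch.nn.Module, max_norm: float = 1.0) -> float:
    # clip_grad_norm_ returns the pre-clipping total gradient norm
    total_norm = clip_grad_norm_(model.parameters(), max_norm)
    return float(total_norm)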
Example #26
def train(net, X_train, y_train, X_test, y_test, epochs=100, lr=0.001, weight_decay=0.005, clip_val=15):
    print("\n\n********** Running training! ************\n\n")
    opt = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
    sched = getLRScheduler(optimizer=opt)
    criterion = nn.CrossEntropyLoss()

    #if (train_on_gpu):
    if torch.cuda.is_available():
        net.cuda()

    train_losses = []
    # results = np.empty([0, 5], dtype=np.float32)
    net.train()

    epoch_train_losses = []
    epoch_train_acc = []
    epoch_test_losses = []
    epoch_test_acc = []
    train_len = len(X_train)
    X_tr = X_train
    y_tr = y_train
    params = {
        'epochs' : [],
        'train_loss' : [],
        'test_loss' : [],
        'lr' : [],
        'train_accuracy' : [],
        'test_accuracy' : []
    }
    for epoch in range(epochs):
        train_losses = []
        step = 1

        h = net.init_hidden(batch_size)

        train_accuracy = 0
        #np.random.shuffle(X_tr)
        #np.random.shuffle(y_tr)

        while step * batch_size <= train_len:
            batch_xs = extract_batch_size(X_tr, step, batch_size)
            # batch_ys = one_hot_vector(extract_batch_size(y_train, step, batch_size))
            batch_ys = extract_batch_size(y_tr, step, batch_size)

            inputs, targets = torch.from_numpy(batch_xs), torch.from_numpy(batch_ys.flatten('F'))
            #if (train_on_gpu):
            if torch.cuda.is_available():
                inputs, targets = inputs.cuda(), targets.cuda()

            h = tuple([each.data for each in h])
            opt.zero_grad()

            output = net(inputs.float(), h)
            # print("lenght of inputs is {} and target value is {}".format(inputs.size(), targets.size()))
            train_loss = criterion(output, targets.long())
            train_losses.append(train_loss.item())

            top_p, top_class = output.topk(1, dim=1)
            equals = top_class == targets.view(*top_class.shape).long()
            train_accuracy += torch.mean(equals.type(torch.FloatTensor))
            equals = top_class

            train_loss.backward()
            clip_grad.clip_grad_norm_(net.parameters(), clip_val)
            opt.step()
            step += 1

        p = opt.param_groups[0]['lr']
        params['lr'].append(p)
        params['epochs'].append(epoch)
        sched.step()
        train_loss_avg = np.mean(train_losses)
        train_accuracy_avg = train_accuracy/(step-1)
        epoch_train_losses.append(train_loss_avg)
        epoch_train_acc.append(train_accuracy_avg)
        print("Epoch: {}/{}...".format(epoch + 1, epochs),
              ' ' * 16 + "Train Loss: {:.4f}".format(train_loss_avg),
              "Train accuracy: {:.4f}...".format(train_accuracy_avg))
        test_loss, test_f1score, test_accuracy = test(net, X_test, y_test, criterion, test_batch=batch_size)
        epoch_test_losses.append(test_loss)
        epoch_test_acc.append(test_accuracy)
        if ((epoch+1) % 10 == 0):
            print("Epoch: {}/{}...".format(epoch + 1, epochs),
                  ' ' * 16 + "Test Loss: {:.4f}...".format(test_loss),
                  "Test accuracy: {:.4f}...".format(test_accuracy),
                  "Test F1: {:.4f}...".format(test_f1score))

    params['train_loss'] = epoch_train_losses
    params['test_loss'] = epoch_test_losses
    params['train_accuracy'] = epoch_train_acc
    params['test_accuracy'] = epoch_test_acc
    return params
Example #27
 def step(self):
     self.max_grad_norm = 0.05
     clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
     self.optimizer.step()
Example #28
 def train(self,
           optimizer,
           n_iters,
           *,
           log_interval=50,
           early_stopping_cnt=0,
           scheduler=None,
           snapshot_interval=2500):
     train_losses = deque(maxlen=self.avg_window)
     train_weights = deque(maxlen=self.avg_window)
     if self.val_loader is not None:
         best_val_loss = 10000
     step, epoch = 0, 0
     wo_improvement = 0
     self.best_performers = []
     self.logger.info("Optimizer {}".format(str(optimizer)))
     self.logger.info("Batches per epoch: {}".format(len(
         self.train_loader)))
     try:
         while step < n_iters:
             epoch += 1
             self.logger.info("=" * 20 + "Epoch {}".format(epoch) +
                              "=" * 20)
             for *input_tensors, y in self.train_loader:
                 input_tensors = [x.to(DEVICE) for x in input_tensors]
                 self.model.train()
                 y = y.to(DEVICE)
                 assert self.model.training
                 optimizer.zero_grad()
                 output = self.model(*input_tensors)
                 batch_loss = self.criterion(output[:, 0], y)
                 batch_loss.backward()
                 train_losses.append(batch_loss.data.cpu().numpy())
                 train_weights.append(y.size(0))
                 clip_grad_norm_(self.model.parameters(), self.clip_grad)
                 optimizer.step()
                 step += 1
                 if (step % log_interval == 0
                         or step % snapshot_interval == 0):
                     train_loss_avg = np.average(train_losses,
                                                 weights=train_weights)
                     self.logger.info(
                         "Step {}: train {:.6f} lr: {:.3e}".format(
                             step, train_loss_avg,
                             optimizer.param_groups[0]['lr']))
                 if self.val_loader is not None and step % snapshot_interval == 0:
                     _, loss = self.predict(self.val_loader, is_test=False)
                     loss_str = "%.6f" % loss
                     self.logger.info("Snapshot loss %s", loss_str)
                     target_path = (
                         CHECKPOINT_DIR /
                         "snapshot_{}_{}.pth".format(self.name, loss_str))
                     heapq.heappush(self.best_performers,
                                    (loss, target_path))
                     torch.save(self.model.state_dict(), target_path)
                     self.logger.info(target_path)
                     assert Path(target_path).exists()
                     if best_val_loss > loss + 1e-4:
                         self.logger.info("New low\n")
                         # self.save_state()
                         best_val_loss = loss
                         wo_improvement = 0
                     else:
                         wo_improvement += 1
                 if scheduler:
                     # old_lr = optimizer.param_groups[0]['lr']
                     scheduler.step()
                     # if old_lr != optimizer.param_groups[0]['lr']:
                     #     # Reload best checkpoint
                     #     self.restore_state()
                 if (self.val_loader is not None and early_stopping_cnt
                         and wo_improvement > early_stopping_cnt):
                     return self.best_performers
                 if step >= n_iters:
                     break
     except KeyboardInterrupt:
         pass
     self.base_steps += step
     return self.best_performers
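A standalone sketch of the best-snapshot bookkeeping in Example #28: snapshots are pushed onto a heap keyed by validation loss, so the best checkpoint always sits at the root. The (loss, path) pairs below are placeholders.

import heapq

# Assumed (loss, path) pairs; in the snippet above they come from the snapshot loop.
best_performers = []
for loss, path in [(0.42, "snapshot_a.pth"), (0.37, "snapshot_b.pth"), (0.55, "snapshot_c.pth")]:
    heapq.heappush(best_performers, (loss, path))

best_loss, best_path = best_performers[0]  # the lowest-loss snapshot is at the heap root
print(best_loss, best_path)  # 0.37 snapshot_b.pth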
Example #29
    def run(self, mt_loader, epoch=None, is_training=True):

        if is_training:
            # TODO: ideally call self._algorithm.train() here instead of reaching into _model
            self._algorithm._model.train()
        else:
            self._algorithm._model.eval()

        # loaders and iterators
        mt_iterator = tqdm(enumerate(mt_loader, start=1),
                           leave=False,
                           file=src.logger.stdout, position=0)
        
        # metrics aggregation
        aggregate = defaultdict(list)
        
        # constants
        n_way = mt_loader.n_way
        n_shot = mt_loader.n_shot
        mt_batch_sz = mt_loader.batch_size
        n_query = mt_loader.n_query
        randomize_query = mt_loader.randomize_query
        print(f"n_way: {n_way}, n_shot: {n_shot}, n_query: {n_query} mt_batch_sz: {mt_batch_sz} randomize_query: {randomize_query}")
        

        for i, mt_batch in mt_iterator:

            # global iterator count
            if is_training:
                self._global_iteration += 1

            analysis = (i % self._log_interval == 0)

            '''
            # legacy code before the data sampling code updates
            # randperm
            if randomize_query and is_training:
                rp = np.random.permutation(2 * n_query * n_way)[:n_query * n_way]
            else:
                rp = None 

            # meta-learning data creation
            mt_batch_x, mt_batch_y = mt_batch
            mt_batch_y = mt_batch_y - self._label_offset
            original_shape = mt_batch_x.shape
            assert len(original_shape) == 5
            # (batch_sz*n_way, n_shot+n_query, channels , height , width)
            mt_batch_x = mt_batch_x.reshape(mt_batch_sz, n_way, *original_shape[-4:])
            # (batch_sz, n_way, n_shot+n_query, channels , height , width)
            shots_x = mt_batch_x[:, :, :n_shot, :, :, :]
            # (batch_sz, n_way, n_shot, channels , height , width)
            if rp is None:
                query_x = mt_batch_x[:, :, n_shot:, :, :, :]
            else:
                query_x = []
                for c in range(n_way):
                    indices = rp[(rp>=(c*2*n_query)) & (rp<((c+1)*2*n_query))] - (c*2*n_query)
                    query_x.append(mt_batch_x[:, c, n_shot + indices, :, :, :])
                query_x = torch.cat(query_x, dim=1)
            # (batch_sz, n_way, n_query, channels , height , width)
            shots_x = shots_x.reshape(mt_batch_sz, -1, *original_shape[-3:])
            # (batch_sz, n_way*n_shot, channels , height , width)
            query_x = query_x.reshape(mt_batch_sz, -1, *original_shape[-3:])
            # (batch_sz, n_way*n_query, channels , height , width)
            shots_y, query_y = get_labels(mt_batch_y, n_way=n_way, 
                n_shot=n_shot, n_query=n_query, batch_sz=mt_batch_sz, rp=rp)
            '''

            shots_x, shots_y, query_x, query_y = mt_batch
            
            assert shots_x.shape[0:2] == (mt_batch_sz, n_way*n_shot)
            assert query_x.shape[0:2] == (mt_batch_sz, n_way*n_query)
            assert shots_y.shape == (mt_batch_sz, n_way*n_shot)
            assert query_y.shape == (mt_batch_sz, n_way*n_query)

            # to cuda
            shots_x = shots_x.cuda()
            query_x = query_x.cuda()
            shots_y = shots_y.cuda()
            query_y = query_y.cuda()
            
            # compute logits and loss on query
            with torch.enable_grad() if is_training else torch.no_grad():
                logits, measurements_trajectory = \
                    self._algorithm.inner_loop_adapt(
                        support=shots_x,
                        support_labels=shots_y,
                        query=query_x,
                        n_way=n_way,
                        n_shot=n_shot,
                        n_query=n_query)

            logits = logits.reshape(-1, logits.size(-1))
            query_y = query_y.reshape(-1)
            assert logits.size(0) == query_y.size(0)
            loss = smooth_loss(
                logits, query_y, logits.shape[1], self._eps)
            accu = accuracy(logits, query_y) * 100.

            # metrics accumulation
            aggregate['mt_outer_loss'].append(loss.item())
            aggregate['mt_outer_accu'].append(accu)
            for k in measurements_trajectory:
                aggregate[k].append(measurements_trajectory[k][-1])
            
            # optimizer step
            if is_training:
                self._optimizer.zero_grad()
                loss.backward()
                if self._grad_clip > 0.:
                    # technically should have a method for algorithm.parameters()
                    clip_grad_norm_(self._algorithm._model.parameters(), 
                                    max_norm=self._grad_clip,
                                    norm_type='inf')
                self._optimizer.step()

            # logging
            if analysis and is_training:
                metrics = {}
                for name, values in aggregate.items():
                    metrics[name] = np.mean(values)
                self.log_output(epoch, i, metrics)
                aggregate = defaultdict(list)    

        # save model and log tboard for eval
        if is_training and self._save_folder is not None:
            save_name = "chkpt_{0:03d}.pt".format(epoch)
            save_path = os.path.join(self._save_folder, save_name)
            with open(save_path, 'wb') as f:
                torch.save({'model': self._algorithm._model.state_dict(),
                           'optimizer': self._optimizer}, f)


        results = {
            'train_loss_trajectory': {
                'loss': np.mean(aggregate['loss']), 
                'accu': np.mean(aggregate['accu']),
            },
            'test_loss_after': {
                'loss': np.mean(aggregate['mt_outer_loss']),
                'accu': np.mean(aggregate['mt_outer_accu']),
            }
        }
        mean, i95 = (np.mean(aggregate['mt_outer_accu']), 
            1.96 * np.std(aggregate['mt_outer_accu']) / np.sqrt(len(aggregate['mt_outer_accu'])))
        results['val_task_acc'] = "{:.2f} ± {:.2f} %".format(mean, i95) 
    
        return results
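A standalone illustration of the norm_type='inf' clipping used in the run method above, on a single made-up parameter.

import torch
from torch.nn.utils import clip_grad_norm_

# With the infinity norm, the "total norm" is the largest absolute gradient
# entry, and every gradient is rescaled by max_norm / total_norm when that
# entry exceeds max_norm.
w = torch.nn.Parameter(torch.zeros(3))
w.grad = torch.tensor([0.5, -4.0, 2.0])
clip_grad_norm_(w, max_norm=1.0, norm_type=float('inf'))
print(w.grad)  # scaled by roughly 1/4 -> approximately [0.125, -1.0, 0.5]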
Example #30
def train(args):
    """Run Training """

    global_step = 0
    best_metric = None
    best_model: Dict[str, torch.Tensor] = dict()
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    writer = SummaryWriter(log_dir=args.output_dir)

    # We use flambe to do the data preprocessing
    # More info at https://flambe.ai
    print("Performing preprocessing (possibly download embeddings).")
    embeddings = args.embeddings if args.use_pretrained_embeddings else None
    text_field = TextField(lower=args.lowercase,
                           embeddings=embeddings,
                           embeddings_format='gensim')
    label_field = LabelField()
    transforms = {'text': text_field, 'label': label_field}
    dataset = TabularDataset.from_path(
        args.train_path,
        args.val_path,
        sep=',' if args.file_type == 'csv' else '\t',
        transform=transforms)

    # Create samplers
    train_sampler = EpisodicSampler(dataset.train,
                                    n_support=args.n_support,
                                    n_query=args.n_query,
                                    n_episodes=args.n_episodes,
                                    n_classes=args.n_classes)

    # The train_eval_sampler is used to compute prototypes over the full dataset
    train_eval_sampler = BaseSampler(dataset.train,
                                     batch_size=args.eval_batch_size)
    val_sampler = BaseSampler(dataset.val, batch_size=args.eval_batch_size)

    if args.device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device

    # Build model, criterion and optimizers
    model = PrototypicalTextClassifier(
        vocab_size=dataset.text.vocab_size,
        distance=args.distance,
        embedding_dim=args.embedding_dim,
        pretrained_embeddings=dataset.text.embedding_matrix,
        rnn_type='sru',
        n_layers=args.n_layers,
        hidden_dim=args.hidden_dim,
        freeze_pretrained_embeddings=True)

    loss_fn = nn.CrossEntropyLoss()

    parameters = (p for p in model.parameters() if p.requires_grad)
    optimizer = torch.optim.Adam(parameters, lr=args.learning_rate)

    print("Beginning training.")
    for epoch in range(args.num_epochs):

        ######################
        #       TRAIN        #
        ######################

        print(f'Epoch: {epoch}')

        model.train()

        with torch.enable_grad():
            for batch in train_sampler:
                # Zero the gradients and clear the accumulated loss
                optimizer.zero_grad()

                # Move to device
                batch = tuple(t.to(device) for t in batch)
                query, query_label, support, support_label = batch

                # Compute loss
                pred = model(query, support, support_label)
                loss = loss_fn(pred, query_label)
                loss.backward()

                # Clip gradients if necessary
                if args.max_grad_norm is not None:
                    clip_grad_norm_(model.parameters(), args.max_grad_norm)

                writer.add_scalar('Training/Loss', loss.item(), global_step)

                # Optimize
                optimizer.step()
                global_step += 1

            # Zero the gradients when exiting a train step
            optimizer.zero_grad()

        #########################
        #       EVALUATE        #
        #########################

        model.eval()

        with torch.no_grad():

            # First compute prototypes over the training data
            encodings, labels = [], []
            for text, label in train_eval_sampler:
                padding_mask = (text != model.padding_idx).byte()
                text_embeddings = model.embedding_dropout(
                    model.embedding(text))
                text_encoding = model.encoder(text_embeddings,
                                              padding_mask=padding_mask)
                labels.append(label.cpu())
                encodings.append(text_encoding.cpu())
            # Compute prototypes
            encodings = torch.cat(encodings, dim=0)
            labels = torch.cat(labels, dim=0)
            prototypes = model.compute_prototypes(encodings, labels).to(device)

            _preds, _targets = [], []
            for batch in val_sampler:
                # Move to device
                source, target = tuple(t.to(device) for t in batch)

                pred = model(source, prototypes=prototypes)
                _preds.append(pred.cpu())
                _targets.append(target.cpu())

            preds = torch.cat(_preds, dim=0)
            targets = torch.cat(_targets, dim=0)

            val_loss = loss_fn(preds, targets).item()
            val_metric = (preds.argmax(dim=1) == targets).float().mean().item()

        # Update best model
        if best_metric is None or val_metric > best_metric:
            best_metric = val_metric
            best_model_state = model.state_dict()
            for k, t in best_model_state.items():
                best_model_state[k] = t.cpu().detach()
            best_model = best_model_state

        # Log metrics
        print(f'Validation loss: {val_loss}')
        print(f'Validation accuracy: {val_metric}')
        writer.add_scalar('Validation/Loss', val_loss, epoch)
        writer.add_scalar('Validation/Accuracy', val_metric, epoch)

    # Save the best model
    print("Finisehd training.")
    torch.save(best_model, os.path.join(args.output_dir, 'model.pt'))
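A minimal sketch (assumed shapes; Example #30 delegates this to model.compute_prototypes) of how class prototypes are formed during evaluation above: each class prototype is the mean encoding of that class's examples.

import torch

def compute_prototypes(encodings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # encodings: (n_examples, hidden_dim); labels: (n_examples,) integer class ids
    classes = labels.unique(sorted=True)
    return torch.stack([encodings[labels == c].mean(dim=0) for c in classes])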