Code example #1
File: train_fm.py Project: yair-schiff/pytorch-PPUU
def train(nbatches, npred):
    model.train()
    total_loss_i, total_loss_s, total_loss_p = 0, 0, 0
    for i in range(nbatches):
        optimizer.zero_grad()
        inputs, actions, targets, _, _ = dataloader.get_batch_fm(
            'train', npred)
        inputs = utils.make_variables(inputs)
        targets = utils.make_variables(targets)
        actions = Variable(actions)
        pred, loss_p = model(inputs, actions, targets, z_dropout=opt.z_dropout)
        loss_p = loss_p[0]
        loss_i, loss_s = compute_loss(targets, pred)
        loss = loss_i + loss_s + opt.beta * loss_p

        # VAEs get NaN loss sometimes, so check for it
        if not math.isnan(loss.item()):
            loss.backward(retain_graph=False)
            if not math.isnan(utils.grad_norm(model).item()):
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               opt.grad_clip)
                optimizer.step()

        total_loss_i += loss_i.item()
        total_loss_s += loss_s.item()
        total_loss_p += loss_p.item()
        del inputs, actions, targets

    total_loss_i /= nbatches
    total_loss_s /= nbatches
    total_loss_p /= nbatches
    return total_loss_i, total_loss_s, total_loss_p
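Every example on this page calls `utils.grad_norm`, whose definition is not shown. As a rough sketch, such a helper plausibly computes the global L2 norm over all available parameter gradients; the version below is an assumption for orientation, not the verbatim helper from any of the quoted projects (note that some examples pass a module and others an iterable of parameters):

import torch


def grad_norm(model_or_params):
    # Sketch only: global L2 norm over all gradients that exist.
    # Accepts an nn.Module or an iterable of parameters, since the
    # examples on this page call it both ways.
    if isinstance(model_or_params, torch.nn.Module):
        params = model_or_params.parameters()
    else:
        params = model_or_params
    grads = [p.grad.detach().flatten() for p in params if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    return torch.cat(grads).norm(2)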
Code example #2
def start(what, nbatches, npred):
    train = what == 'train'
    model.train()
    model.policy_net.train()
    n_updates, grad_norm = 0, 0
    total_losses = dict(
        proximity=0,
        uncertainty=0,
        lane=0,
        offroad=0,
        action=0,
        policy=0,
    )
    for j in range(nbatches):
        inputs, actions, targets, ids, car_sizes = dataloader.get_batch_fm(
            what, npred)
        pred, actions = planning.train_policy_net_mpur(
            model,
            inputs,
            targets,
            car_sizes,
            n_models=10,
            lrt_z=opt.lrt_z,
            n_updates_z=opt.z_updates,
            infer_z=opt.infer_z)
        pred['policy'] = pred['proximity'] + \
                         opt.u_reg * pred['uncertainty'] + \
                         opt.lambda_l * pred['lane'] + \
                         opt.lambda_a * pred['action'] + \
                         opt.lambda_o * pred['offroad']

        if not math.isnan(pred['policy'].item()):
            if train:
                optimizer.zero_grad()
                pred['policy'].backward()  # back-propagation through time!
                grad_norm += utils.grad_norm(model.policy_net).item()
                torch.nn.utils.clip_grad_norm_(model.policy_net.parameters(),
                                               opt.grad_clip)
                optimizer.step()
            for loss in total_losses:
                total_losses[loss] += pred[loss].item()
            n_updates += 1
        else:
            print('warning, NaN')  # Oh no... Something got quite f****d up!
            ipdb.set_trace()

        if j == 0 and opt.save_movies and train:
            # save videos of normal and adversarial scenarios
            for b in range(opt.batch_size):
                state_img = pred['state_img'][b]
                state_vct = pred['state_vct'][b]
                utils.save_movie(opt.model_file + f'.mov/sampled/mov{b}',
                                 state_img, state_vct, None, actions[b])

        del inputs, actions, targets, pred

    for loss in total_losses:
        total_losses[loss] /= n_updates
    if train: print(f'[avg grad norm: {grad_norm / n_updates:.4f}]')
    return total_losses
Code example #3
def train(nbatches, npred):
    model.train()
    total_loss = 0
    for i in range(nbatches):
        optimizer.zero_grad()
        inputs, actions, targets, _, _ = dataloader.get_batch_fm(
            'train', npred)
        pred, _ = model(inputs, actions, targets, z_dropout=0)
        pred_cost = cost(
            pred[0].view(opt.batch_size * opt.npred, 1, 3, opt.height,
                         opt.width), pred[1].view(opt.batch_size * opt.npred,
                                                  1, 4))
        loss = F.mse_loss(pred_cost.view(opt.batch_size, opt.npred, 2),
                          targets[2])
        if not math.isnan(loss.item()):
            loss.backward(retain_graph=False)
            if not math.isnan(utils.grad_norm(model).item()):
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               opt.grad_clip)
                optimizer.step()
            total_loss += loss.item()
        del inputs, actions, targets

    total_loss /= nbatches
    return total_loss
Code example #4
    def adapt(self, indices, inputs, lbls):
        context_x, context_y = self.mem.fetch(inputs)

        out = self.forward(inputs)
        self.criterion.reduce = False
        loss_out = self.criterion(out, lbls.squeeze(0))
        self.criterion.reduce = True
        gnorms = []
        for loss in loss_out:
            self.zero_grad()
            loss.backward(retain_graph=True)
            gnorm = utils.grad_norm(self.parameters())
            gnorms.append(gnorm.item())
        self.zero_grad()
        gnorms = np.array(gnorms)
        gnorms /= gnorms.sum()
        indices_sampled = np.random.choice(len(gnorms),
                                           self.bsz_sampling,
                                           p=gnorms)
        indices_sampled = list(set(indices_sampled))
        self.mem.add(inputs[indices_sampled], lbls[indices_sampled])

        if context_x is not None and context_y is not None:
            out = self.forward(torch.cat([inputs, context_x], dim=0))
            lbl = torch.cat([lbls, context_y], dim=0)
        else:
            out = self.forward(inputs)
            lbl = lbls
        loss = self.criterion(out, lbl.squeeze(0))
        self.nsteps += 1
        return loss
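The `adapt` method above backpropagates each per-example loss separately, records the resulting gradient norms, and samples memory candidates in proportion to them, biasing the replay buffer towards high-gradient (informative) examples. A self-contained sketch of that scoring step, with illustrative names that are not taken from the project:

import numpy as np
import torch
import torch.nn as nn


def per_example_grad_norms(model, inputs, labels):
    # One backward pass per example; retain_graph keeps the shared
    # forward graph alive for the later examples in the batch.
    criterion = nn.CrossEntropyLoss(reduction='none')
    losses = criterion(model(inputs), labels)
    gnorms = []
    for loss in losses:
        model.zero_grad()
        loss.backward(retain_graph=True)
        total = torch.sqrt(sum((p.grad ** 2).sum()
                               for p in model.parameters()
                               if p.grad is not None))
        gnorms.append(total.item())
    model.zero_grad()
    return np.asarray(gnorms)


# usage: sample a subset biased towards high-gradient examples
model = nn.Linear(8, 3)
x, y = torch.randn(16, 8), torch.randint(0, 3, (16,))
g = per_example_grad_norms(model, x, y)
picked = np.random.choice(len(g), size=4, p=g / g.sum())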
Code example #5
    def adapt(self, indices, inputs, lbls):
        context_i, context_x, context_y, context_g = \
            self.mem.fetch(inputs)

        out = self.baby_mlp(inputs)
        self.criterion.reduce = False
        loss_out = self.criterion(out, lbls.squeeze(0))
        self.criterion.reduce = True
        gnorms = []
        for loss in loss_out:
            self.baby_mlp.zero_grad()
            loss.backward(retain_graph=True)
            gnorm = utils.grad_norm(self.baby_mlp.parameters())
            gnorms.append(gnorm)
        self.baby_mlp.zero_grad()

        if self.nsteps % self.add_per == 0:
            self.mem.add(indices, inputs, lbls, gnorms)

        if context_x is not None and context_y is not None:
            out = self.forward(torch.cat([inputs, context_x], dim=0))
            lbl = torch.cat([lbls, context_y], dim=0)
        else:
            out = self.forward(inputs)
            lbl = lbls

        loss = self.criterion(out, lbl.squeeze(0))
        self.nsteps += 1

        return loss
Code example #6
File: bism.py Project: baofff/BiSM
def _update(v, runner, criterion):
    opt_model, opt_q = criterion.opts["model"], criterion.opts["q"]
    sch_model, sch_q = criterion.schs["model"], criterion.schs["q"]

    # update q
    runner.model.requires_grad_(False)
    runner.q.requires_grad_(True)
    for i in range(config.get("update", "n_inner_loops")):
        opt_q.zero_grad()
        criterion.mid_vals["inner_loss"] = inner_loss = criterion.inner_loss(v).mean()
        inner_loss.backward()  # backward
        if config.get("update", "gradient_clip", default=False):
            clip_grad_norm_(runner.q.parameters(), 0.5)
        opt_q.step()  # step
    sch_q.step()

    # update model
    runner.model.requires_grad_(True)
    runner.q.requires_grad_(True)
    backup_q_state_dict = runner.q.state_dict()
    with higher.innerloop_ctx(runner.q, opt_q) as (fq, diffopt_q):
        criterion.q = fq
        for i in range(config.get("update", "n_unroll")):
            inner_loss = criterion.inner_loss(v).mean()
            diffopt_q.step(inner_loss)
        criterion.loss_val = loss = criterion.loss(v).mean()
        opt_model.zero_grad()
        runner.model.requires_grad_(True)
        runner.q.requires_grad_(False)
        loss.backward()
        if config.get("update", "gradient_clip", default=False):
            clip_grad_norm_(runner.model.parameters(), 0.5)
        opt_model.step()
        sch_model.step()
    runner.q.load_state_dict(backup_q_state_dict)
    criterion.q = runner.q

    criterion.mid_vals["grad_model"] = grad_norm(runner.model)
    criterion.mid_vals["grad_q"] = grad_norm(runner.q)
Code example #7
def finish_episode():
    R = 0
    saved_actions = model.saved_actions
    policy_losses = []
    value_losses = []
    rewards = []
    for r in model.rewards[::-1]:
        R = r + opt.gamma * R
        rewards.insert(0, R)
    rewards = torch.tensor(rewards).cuda()
    rewards = (rewards - rewards.mean()) / (rewards.std() + eps)
    for (log_prob, value), r in zip(saved_actions, rewards):
        reward = r - value.item()
        policy_losses.append(-log_prob * reward)
        value_losses.append(F.smooth_l1_loss(value, torch.tensor([r]).cuda()))
    optimizer.zero_grad()
    loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
    loss.backward()
    print(utils.grad_norm(model).item())
    optimizer.step()
    del model.rewards[:]
    del model.saved_actions[:]
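`finish_episode` is the standard actor-critic update: discounted returns are accumulated right-to-left over the episode, normalized, and then used both as the advantage for the policy term and as the regression target for the value term. The return computation in isolation (`gamma` and `eps` mirror the globals assumed above):

import torch


def discounted_returns(rewards, gamma=0.99, eps=1e-8):
    # R_t = r_t + gamma * R_{t+1}, computed backwards, then normalized
    returns, R = [], 0.0
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)
    returns = torch.tensor(returns)
    return (returns - returns.mean()) / (returns.std() + eps)


print(discounted_returns([1.0, 0.0, 0.0, 1.0]))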
Code example #8
def main():
    best_score = 0
    args = get_train_args()
    print(json.dumps(args.__dict__, indent=4))

    # Reading the int indexed text dataset
    train_data = np.load(os.path.join(args.input, args.data + ".train.npy"))
    train_data = train_data.tolist()
    dev_data = np.load(os.path.join(args.input, args.data + ".valid.npy"))
    dev_data = dev_data.tolist()
    test_data = np.load(os.path.join(args.input, args.data + ".test.npy"))
    test_data = test_data.tolist()

    # Reading the vocab file
    with open(os.path.join(args.input, args.data + '.vocab.pickle'),
              'rb') as f:
        id2w = pickle.load(f)

    args.id2w = id2w
    args.n_vocab = len(id2w)

    # Define Model
    model = net.Transformer(args)

    tally_parameters(model)
    if args.gpu >= 0:
        model.cuda(args.gpu)
    print(model)

    optimizer = optim.TransformerAdamTrainer(model, args)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.model_file):
            print("=> loading checkpoint '{}'".format(args.model_file))
            checkpoint = torch.load(args.model_file)
            args.start_epoch = checkpoint['epoch']
            best_score = checkpoint['best_score']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.model_file, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.model_file))

    src_data, trg_data = list(zip(*train_data))
    total_src_words = len(list(itertools.chain.from_iterable(src_data)))
    total_trg_words = len(list(itertools.chain.from_iterable(trg_data)))
    iter_per_epoch = (total_src_words + total_trg_words) // args.wbatchsize
    print('Approximate number of iter/epoch =', iter_per_epoch)
    time_s = time()

    global_steps = 0
    for epoch in range(args.start_epoch, args.epoch):
        random.shuffle(train_data)
        train_iter = data.iterator.pool(
            train_data,
            args.wbatchsize,
            key=lambda x: data.utils.interleave_keys(len(x[0]), len(x[1])),
            batch_size_fn=batch_size_func,
            random_shuffler=data.iterator.RandomShuffler())
        report_stats = utils.Statistics()
        train_stats = utils.Statistics()
        valid_stats = utils.Statistics()

        if args.debug:
            grad_norm = 0.
        for num_steps, train_batch in enumerate(train_iter):
            global_steps += 1
            model.train()
            optimizer.zero_grad()
            src_iter = list(zip(*train_batch))[0]
            src_words = len(list(itertools.chain.from_iterable(src_iter)))
            report_stats.n_src_words += src_words
            train_stats.n_src_words += src_words
            in_arrays = utils.seq2seq_pad_concat_convert(train_batch, -1)
            loss, stat = model(*in_arrays)
            loss.backward()
            if args.debug:
                norm = utils.grad_norm(model.parameters())
                grad_norm += norm
                if global_steps % args.report_every == 0:
                    print("> Gradient Norm: %1.4f" % (grad_norm /
                                                      (num_steps + 1)))
            optimizer.step()

            report_stats.update(stat)
            train_stats.update(stat)
            report_stats = report_func(epoch, num_steps, iter_per_epoch,
                                       time_s, report_stats, args.report_every)

            if (global_steps + 1) % args.eval_steps == 0:
                dev_iter = data.iterator.pool(
                    dev_data,
                    args.wbatchsize,
                    key=lambda x: data.utils.interleave_keys(
                        len(x[0]), len(x[1])),
                    batch_size_fn=batch_size_func,
                    random_shuffler=data.iterator.RandomShuffler())

                for dev_batch in dev_iter:
                    model.eval()
                    in_arrays = utils.seq2seq_pad_concat_convert(dev_batch, -1)
                    loss_test, stat = model(*in_arrays)
                    valid_stats.update(stat)

                print('Train perplexity: %g' % train_stats.ppl())
                print('Train accuracy: %g' % train_stats.accuracy())

                print('Validation perplexity: %g' % valid_stats.ppl())
                print('Validation accuracy: %g' % valid_stats.accuracy())

                bleu_score, _ = CalculateBleu(model,
                                              dev_data,
                                              'Dev Bleu',
                                              batch=args.batchsize // 4,
                                              beam_size=args.beam_size,
                                              alpha=args.alpha,
                                              max_sent=args.max_sent_eval)()
                if args.metric == "bleu":
                    score = bleu_score
                elif args.metric == "accuracy":
                    score = valid_stats.accuracy()

                is_best = score > best_score
                best_score = max(score, best_score)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'best_score': best_score,
                        'optimizer': optimizer.state_dict(),
                        'opts': args,
                    }, is_best, args.model_file, args.best_model_file)

    # BLEU score on Dev and Test Data
    checkpoint = torch.load(args.best_model_file)
    print("=> loaded checkpoint '{}' (epoch {}, best score {})".format(
        args.best_model_file, checkpoint['epoch'], checkpoint['best_score']))
    model.load_state_dict(checkpoint['state_dict'])

    print('Dev Set BLEU Score')
    _, dev_hyp = CalculateBleu(model,
                               dev_data,
                               'Dev Bleu',
                               batch=args.batchsize // 4,
                               beam_size=args.beam_size,
                               alpha=args.alpha)()
    save_output(dev_hyp, id2w, args.dev_hyp)

    print('Test Set BLEU Score')
    _, test_hyp = CalculateBleu(model,
                                test_data,
                                'Test Bleu',
                                batch=args.batchsize // 4,
                                beam_size=args.beam_size,
                                alpha=args.alpha)()
    save_output(test_hyp, id2w, args.test_hyp)
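Both Transformer trainers batch by token budget (`args.wbatchsize`) rather than by sentence count, passing a `batch_size_fn` to `data.iterator.pool`. In the torchtext convention such a function receives the new example, the number of examples collected so far, and the running size, and returns the updated size; the actual `batch_size_func` is defined elsewhere in these projects, but a plausible token-counting version looks like this:

def batch_size_func(new, count, size_so_far):
    # Measure batch size in source + target tokens instead of sentences.
    # Assumes `new` is a (src_ids, trg_ids) pair, matching the train_data
    # layout above; torchtext keeps adding examples until this value
    # exceeds the wbatchsize budget.
    return size_so_far + len(new[0]) + len(new[1])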
Code example #9
def train(nbatches, npred):
    model.train()
    model.policy_net.train()
    total_loss_c, total_loss_u, total_loss_l, total_loss_a, n_updates, grad_norm = 0, 0, 0, 0, 0, 0
    total_loss_policy = 0
    for j in range(nbatches):
        optimizer.zero_grad()
        inputs, actions, targets, ids, car_sizes = dataloader.get_batch_fm(
            'train', npred)
        inputs = utils.make_variables(inputs)
        targets = utils.make_variables(targets)
        pred, actions, pred_adv = planning.train_policy_net_mpur(
            model,
            inputs,
            targets,
            car_sizes,
            n_models=10,
            lrt_z=opt.lrt_z,
            n_updates_z=opt.z_updates,
            infer_z=(opt.infer_z == 1))
        loss_c = pred[2]  # proximity cost
        loss_l = pred[3]  # lane cost
        loss_u = pred[4]  # uncertainty cost
        loss_a = actions.norm(2, 2).pow(2).mean()  # action regularisation
        loss_policy = loss_c + opt.u_reg * loss_u + opt.lambda_l * loss_l + opt.lambda_a * loss_a

        if not math.isnan(loss_policy.item()):
            loss_policy.backward()  # back-propagation through time!
            grad_norm += utils.grad_norm(model.policy_net).item()
            torch.nn.utils.clip_grad_norm_(model.policy_net.parameters(),
                                           opt.grad_clip)
            optimizer.step()
            total_loss_c += loss_c.item()  # proximity cost
            total_loss_u += loss_u.item()  # uncertainty (reg.)
            total_loss_a += loss_a.item()  # action (reg.)
            total_loss_l += loss_l.item()  # lane cost
            total_loss_policy += loss_policy.item()  # overall total cost
            n_updates += 1
        else:
            print('warning, NaN')  # Oh no... Something got quite f****d up!
            pdb.set_trace()

        if j == 0 and opt.save_movies:
            # save videos of normal and adversarial scenarios
            for b in range(opt.batch_size):
                utils.save_movie(opt.model_file + f'.mov/sampled/mov{b}',
                                 pred[0][b], pred[1][b], None, actions[b])
                if pred_adv[0] is not None:
                    utils.save_movie(
                        opt.model_file + f'.mov/adversarial/mov{b}',
                        pred_adv[0][b], pred_adv[1][b], None, actions[b])

        del inputs, actions, targets, pred

    total_loss_c /= n_updates
    total_loss_u /= n_updates
    total_loss_a /= n_updates
    total_loss_l /= n_updates
    total_loss_policy /= n_updates
    print(f'[avg grad norm: {grad_norm / n_updates}]')
    return total_loss_c, total_loss_l, total_loss_u, total_loss_a, total_loss_policy
Code example #10
File: train.py Project: imohammad12/multilingual_nmt
def main():
    best_score = 0
    args = get_train_args()
    logger = get_logger(args.log_path)
    logger.info(json.dumps(args.__dict__, indent=4))

    # Set seed value
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    if args.gpu:
        torch.cuda.manual_seed_all(args.seed)

    # Reading the int indexed text dataset
    train_data = np.load(os.path.join(args.input, args.data + ".train.npy"),
                         allow_pickle=True)
    train_data = train_data.tolist()
    dev_data = np.load(os.path.join(args.input, args.data + ".valid.npy"),
                       allow_pickle=True)
    dev_data = dev_data.tolist()
    test_data = np.load(os.path.join(args.input, args.data + ".test.npy"),
                        allow_pickle=True)
    test_data = test_data.tolist()

    # Reading the vocab file
    with open(os.path.join(args.input, args.data + '.vocab.pickle'),
              'rb') as f:
        id2w = pickle.load(f)

    args.id2w = id2w
    args.n_vocab = len(id2w)

    # Define Model
    model = eval(args.model)(args)
    model.apply(init_weights)

    tally_parameters(model)
    if args.gpu >= 0:
        model.cuda(args.gpu)
    logger.info(model)

    if args.optimizer == 'Noam':
        optimizer = NoamAdamTrainer(model, args)
    elif args.optimizer == 'Adam':
        params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = torch.optim.Adam(params,
                                     lr=args.learning_rate,
                                     betas=(args.optimizer_adam_beta1,
                                            args.optimizer_adam_beta2),
                                     eps=args.optimizer_adam_epsilon)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               mode='max',
                                                               factor=0.7,
                                                               patience=7,
                                                               verbose=True)
    elif args.optimizer == 'Yogi':
        params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = Yogi(params,
                         lr=args.learning_rate,
                         betas=(args.optimizer_adam_beta1,
                                args.optimizer_adam_beta2),
                         eps=args.optimizer_adam_epsilon)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               mode='max',
                                                               factor=0.7,
                                                               patience=7,
                                                               verbose=True)

    if args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={'init_scale': 2**16},
                                   verbose=False)

    ema = ExponentialMovingAverage(decay=args.ema_decay)
    ema.register(model.state_dict())

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.model_file):
            logger.info("=> loading checkpoint '{}'".format(args.model_file))
            checkpoint = torch.load(args.model_file)
            args.start_epoch = checkpoint['epoch']
            best_score = checkpoint['best_score']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.model_file, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(
                args.model_file))

    src_data, trg_data = list(zip(*train_data))
    total_src_words = len(list(itertools.chain.from_iterable(src_data)))
    total_trg_words = len(list(itertools.chain.from_iterable(trg_data)))
    iter_per_epoch = (total_src_words + total_trg_words) // (2 *
                                                             args.wbatchsize)
    logger.info('Approximate number of iter/epoch = {}'.format(iter_per_epoch))
    time_s = time()

    global_steps = 0
    num_grad_steps = 0
    if args.grad_norm_for_yogi and args.optimizer == 'Yogi':
        args.start_epoch = -1
        l2_norm = 0.0
        parameters = list(
            filter(lambda p: p.requires_grad, model.parameters()))
        n_params = sum([p.nelement() for p in parameters])

    for epoch in range(args.start_epoch, args.epoch):
        random.shuffle(train_data)
        train_iter = data.iterator.pool(
            train_data,
            args.wbatchsize,
            key=lambda x: (len(x[0]), len(x[1])),
            batch_size_fn=batch_size_fn,
            random_shuffler=data.iterator.RandomShuffler())
        report_stats = utils.Statistics()
        train_stats = utils.Statistics()
        if args.debug:
            grad_norm = 0.
        for num_steps, train_batch in enumerate(train_iter):
            global_steps += 1
            model.train()
            if args.grad_accumulator_count == 1:
                optimizer.zero_grad()
            elif num_grad_steps % args.grad_accumulator_count == 0:
                optimizer.zero_grad()
            src_iter = list(zip(*train_batch))[0]
            src_words = len(list(itertools.chain.from_iterable(src_iter)))
            report_stats.n_src_words += src_words
            train_stats.n_src_words += src_words
            in_arrays = utils.seq2seq_pad_concat_convert(train_batch, -1)
            if len(args.multi_gpu) > 1:
                loss_tuple, stat_tuple = zip(
                    *dp(model, in_arrays, device_ids=args.multi_gpu))
                n_total = sum([obj.n_words.item() for obj in stat_tuple])
                n_correct = sum([obj.n_correct.item() for obj in stat_tuple])
                loss = 0
                for l_, s_ in zip(loss_tuple, stat_tuple):
                    loss += l_ * s_.n_words.item()
                loss /= n_total
                stat = utils.Statistics(loss=loss.data.cpu() * n_total,
                                        n_correct=n_correct,
                                        n_words=n_total)
            else:
                loss, stat = model(*in_arrays)

            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()
            if epoch == -1 and args.grad_norm_for_yogi and args.optimizer == 'Yogi':
                l2_norm += (utils.grad_norm(model.parameters())**2) / n_params
                continue
            num_grad_steps += 1
            if args.debug:
                norm = utils.grad_norm(model.parameters())
                grad_norm += norm
                if global_steps % args.report_every == 0:
                    logger.info("> Gradient Norm: %1.4f" % (grad_norm /
                                                            (num_steps + 1)))
            if args.grad_accumulator_count == 1:
                optimizer.step()
                ema.apply(model.state_dict(keep_vars=True))
            elif num_grad_steps % args.grad_accumulator_count == 0:
                optimizer.step()
                ema.apply(model.state_dict(keep_vars=True))
                num_grad_steps = 0
            report_stats.update(stat)
            train_stats.update(stat)
            report_stats = report_func(epoch, num_steps, iter_per_epoch,
                                       time_s, report_stats, args.report_every)

            valid_stats = utils.Statistics()
            if global_steps % args.eval_steps == 0:
                with torch.no_grad():
                    dev_iter = data.iterator.pool(
                        dev_data,
                        args.wbatchsize,
                        key=lambda x: (len(x[0]), len(x[1])),
                        batch_size_fn=batch_size_fn,
                        random_shuffler=data.iterator.RandomShuffler())

                    for dev_batch in dev_iter:
                        model.eval()
                        in_arrays = utils.seq2seq_pad_concat_convert(
                            dev_batch, -1)
                        if len(args.multi_gpu) > 1:
                            _, stat_tuple = zip(*dp(
                                model, in_arrays, device_ids=args.multi_gpu))
                            n_total = sum(
                                [obj.n_words.item() for obj in stat_tuple])
                            n_correct = sum(
                                [obj.n_correct.item() for obj in stat_tuple])
                            dev_loss = sum([obj.loss for obj in stat_tuple])
                            stat = utils.Statistics(loss=dev_loss,
                                                    n_correct=n_correct,
                                                    n_words=n_total)
                        else:
                            _, stat = model(*in_arrays)
                        valid_stats.update(stat)

                    logger.info('Train perplexity: %g' % train_stats.ppl())
                    logger.info('Train accuracy: %g' % train_stats.accuracy())

                    logger.info('Validation perplexity: %g' %
                                valid_stats.ppl())
                    logger.info('Validation accuracy: %g' %
                                valid_stats.accuracy())

                    if args.metric == "accuracy":
                        score = valid_stats.accuracy()
                    elif args.metric == "bleu":
                        score, _ = CalculateBleu(
                            model,
                            dev_data,
                            'Dev Bleu',
                            batch=args.batchsize // 4,
                            beam_size=args.beam_size,
                            alpha=args.alpha,
                            max_sent=args.max_sent_eval)(logger)

                    # Threshold Global Steps to save the model
                    if not (global_steps % 2000):
                        print('saving')
                        is_best = score > best_score
                        best_score = max(score, best_score)
                        save_checkpoint(
                            {
                                'epoch': epoch + 1,
                                'state_dict': model.state_dict(),
                                'state_dict_ema': ema.shadow_variable_dict,
                                'best_score': best_score,
                                'optimizer': optimizer.state_dict(),
                                'opts': args,
                            }, is_best, args.model_file, args.best_model_file)

                    if args.optimizer == 'Adam' or args.optimizer == 'Yogi':
                        scheduler.step(score)

        if epoch == -1 and args.grad_norm_for_yogi and args.optimizer == 'Yogi':
            optimizer.v_init = l2_norm / (num_steps + 1)
            logger.info("Initializing Yogi Optimizer (v_init = {})".format(
                optimizer.v_init))

    # BLEU score on Dev and Test Data
    checkpoint = torch.load(args.best_model_file)
    logger.info("=> loaded checkpoint '{}' (epoch {}, best score {})".format(
        args.best_model_file, checkpoint['epoch'], checkpoint['best_score']))
    model.load_state_dict(checkpoint['state_dict'])

    logger.info('Dev Set BLEU Score')
    _, dev_hyp = CalculateBleu(model,
                               dev_data,
                               'Dev Bleu',
                               batch=args.batchsize // 4,
                               beam_size=args.beam_size,
                               alpha=args.alpha,
                               max_decode_len=args.max_decode_len)(logger)
    save_output(dev_hyp, id2w, args.dev_hyp)

    logger.info('Test Set BLEU Score')
    _, test_hyp = CalculateBleu(model,
                                test_data,
                                'Test Bleu',
                                batch=args.batchsize // 4,
                                beam_size=args.beam_size,
                                alpha=args.alpha,
                                max_decode_len=args.max_decode_len)(logger)
    save_output(test_hyp, id2w, args.test_hyp)

    # Loading EMA state dict
    model.load_state_dict(checkpoint['state_dict_ema'])
    logger.info('Dev Set BLEU Score')
    _, dev_hyp = CalculateBleu(model,
                               dev_data,
                               'Dev Bleu',
                               batch=args.batchsize // 4,
                               beam_size=args.beam_size,
                               alpha=args.alpha,
                               max_decode_len=args.max_decode_len)(logger)
    save_output(dev_hyp, id2w, args.dev_hyp + '.ema')

    logger.info('Test Set BLEU Score')
    _, test_hyp = CalculateBleu(model,
                                test_data,
                                'Test Bleu',
                                batch=args.batchsize // 4,
                                beam_size=args.beam_size,
                                alpha=args.alpha,
                                max_decode_len=args.max_decode_len)(logger)
    save_output(test_hyp, id2w, args.test_hyp + '.ema')
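Example #10 keeps an exponential moving average of the weights and evaluates both the raw and the EMA checkpoint at the end. The `ExponentialMovingAverage` class itself is not shown; a minimal sketch matching the `register`/`apply`/`shadow_variable_dict` usage above (an assumption about the interface, not the project's actual code):

import torch


class ExponentialMovingAverage:
    # shadow <- decay * shadow + (1 - decay) * current, per tensor

    def __init__(self, decay=0.999):
        self.decay = decay
        self.shadow_variable_dict = {}

    def register(self, state_dict):
        # snapshot the initial weights as the starting shadow values
        for name, tensor in state_dict.items():
            self.shadow_variable_dict[name] = tensor.detach().clone()

    def apply(self, state_dict):
        # called after each optimizer step, e.g. with keep_vars=True
        with torch.no_grad():
            for name, tensor in state_dict.items():
                shadow = self.shadow_variable_dict[name]
                shadow.mul_(self.decay).add_(tensor.detach(),
                                             alpha=1.0 - self.decay)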
Code example #11
def start(what,
          nbatches,
          npred,
          split='train',
          return_per_instance_values=False,
          threshold=0):
    train = what == 'train'
    evaluate = what == 'eval'
    finetune_train = split == 'finetune_train'
    finetune_sim = split == 'finetune_sim'

    model.train()
    model.policy_net.train()
    n_updates, grad_norm = 0, 0

    if return_per_instance_values:
        total_losses = dict(
            proximity=[],
            uncertainty=[],
            lane=[],
            offroad=[],
            action=[],
            policy=[],
            episode_timestep_pairs=[],
        )

    else:
        total_losses = dict(
            proximity=0,
            uncertainty=0,
            lane=0,
            offroad=0,
            action=0,
            policy=0,
        )

    if evaluate:
        episode_cost_progression = {}

    iterable = range(nbatches)
    if evaluate or finetune_train:
        total_instances = dataloader.get_total_instances(split, what)
        print(f"total_instances in {split}: {total_instances}")
        iterable = range(0, total_instances, opt.batch_size)

    step = 0
    for j in iterable:
        #             print("j:",j,n_updates)
        #         with tqdm(total=len(iterable)) as progress_bar:
        #             start = time.time()
        if not evaluate:
            inputs, actions, targets, ids, car_sizes = dataloader.get_batch_fm(
                split,
                npred,
                cuda=(torch.cuda.is_available() and not opt.no_cuda),
                all_batches=finetune_train)
        else:
            e_index, inputs, actions, targets, ids, car_sizes = dataloader.get_batch_fm(
                split,
                npred,
                return_episode_index=True,
                cuda=(torch.cuda.is_available() and not opt.no_cuda),
                all_batches=finetune_train,
                randomize=(finetune_train or finetune_sim))
            if -1 in e_index[:, 0]:
                print("breaking now")
                break

        pred, actions = planning.train_policy_net_mpur(
            model,
            inputs,
            targets,
            car_sizes,
            n_models=10,
            lrt_z=opt.lrt_z,
            n_updates_z=opt.z_updates,
            infer_z=opt.infer_z,
            no_cuda=opt.no_cuda,
            return_per_instance_values=evaluate)
        pred['policy'] = pred['proximity'] + \
                         opt.u_reg * pred['uncertainty'] + \
                         opt.lambda_l * pred['lane'] + \
                         opt.lambda_a * pred['action'] + \
                         opt.lambda_o * pred['offroad']

        if not math.isnan(torch.mean(pred['policy']).item()):
            # (reconstructed: these lines were garbled in the source;
            # the pattern mirrors code example #2 above)
            if train:
                optimizer.zero_grad()
                pred['policy'].backward()  # back-propagation through time!
                grad_norm += utils.grad_norm(model.policy_net).item()
                torch.nn.utils.clip_grad_norm_(model.policy_net.parameters(),
                                               opt.grad_clip)
                optimizer.step()
            if not evaluate:
                for loss in total_losses:
                    total_losses[loss] += pred[loss].item()
                n_updates += 1
            else:
                # per-instance bookkeeping for evaluation runs
                for b_i in range(len(pred['policy'])):
                    instance_loss = torch.mean(pred['policy'][b_i]).item()
                    if e_index[b_i][0] not in episode_cost_progression:
                        episode_cost_progression[e_index[b_i][0]] = {
                            key: []
                            for key in total_losses if "episode" not in key
                        }

                    if instance_loss >= threshold:
                        episode_index, timestep = e_index[b_i]

                        for loss in total_losses:
                            if loss != 'episode_timestep_pairs':
                                total_losses[loss].append(
                                    torch.mean(pred[loss][b_i]).detach().cpu())
                                episode_cost_progression[
                                    e_index[b_i][0]][loss].append(
                                        torch.mean(
                                            pred[loss][b_i]).detach().cpu())
                            else:
                                total_losses[loss].append(
                                    [episode_index, timestep, instance_loss])


        else:
            print('warning, NaN')  # Oh no... Something got quite f****d up!
            ipdb.set_trace()

        if j == 0 and opt.save_movies and train:
            # save videos of normal and adversarial scenarios
            for b in range(opt.batch_size):
                state_img = pred['state_img'][b]
                state_vct = pred['state_vct'][b]
                utils.save_movie(opt.model_file + f'.mov/sampled/mov{b}',
                                 state_img, state_vct, None, actions[b])

        del inputs, actions, targets, pred

        if n_updates == nbatches and not evaluate:
            break

    if not evaluate:
        for loss in total_losses:
            total_losses[loss] /= n_updates

    if train or finetune_train or finetune_sim:
        print(f'[avg grad norm: {grad_norm / n_updates:.4f}]')

    if evaluate:
        pickle.dump(
            episode_cost_progression,
            open("policy_loss_stats/episode_cost_progression.pkl", 'wb+'))
    return total_losses
Code example #12
    def train(self):
        start_t = time.time()
        print(f'Training started at {datetime.now()}')
        print(f'Total number of batches: {len(self.data_loader_train)}')

        best_valid_loss, best_train_epoch_loss, best_roc_auc = 10, 10, 0
        best_step_train_loss, best_step_valid_loss, best_step_valid_roc = 0, 0, 0
        drop_counter = 0
        loss_fn = self.model.loss()

        for epoch in range(self.config.num_epochs):
            epoch_loss = 0
            self.model.train()
            ctr = 0
            for ctr, (audio, target,
                      fname) in enumerate(self.data_loader_train):
                drop_counter += 1
                audio = audio.to(self.device)
                target = target.to(self.device)

                # Time-frequency transform
                if self.transforms is not None:
                    audio = self.transforms(audio)

                # predict
                out = self.model(audio)
                loss = loss_fn(out, target)

                # back propagation
                self.optimizer.zero_grad()
                loss.backward()
                if self.config.clip_grad > 0:
                    clip_grad_norm_(self.model.parameters(),
                                    self.config.clip_grad)
                self.optimizer.step()

                epoch_loss += loss.item()

                # print log
                if ctr % self.config.print_every == 0:
                    print(
                        "[%s] Epoch [%d/%d] Iter [%d/%d] train loss: %.4f Elapsed: %s"
                        % (datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                           epoch + 1, self.config.num_epochs, ctr,
                           len(self.data_loader_train), loss.item(),
                           timedelta(seconds=time.time() - start_t)))

                if self.writer is not None:
                    step = epoch * len(self.data_loader_train) + ctr
                    self.writer.add_scalar('loss', loss.item(), step)
                    self.writer.add_scalar(
                        'learning_rate', self.optimizer.param_groups[0]['lr'],
                        step)
                    self.writer.add_scalar(
                        'grad_norm', utils.grad_norm(self.model.parameters()),
                        step)

            del audio, target
            epoch_loss = epoch_loss / len(self.data_loader_train)

            # validation
            valid_loss, scores, y_true, y_pred = self._validation(
                start_t, epoch)
            if self.scheduler is not None:
                if self.config.scheduler == 'plateau':
                    self.scheduler.step(valid_loss)
                else:
                    self.scheduler.step()

            # Log validation
            if self.writer is not None:
                step = epoch * len(self.data_loader_train) + ctr
                self.writer.add_scalar('valid_loss', valid_loss, step)
                self.writer.add_scalar('valid_roc_auc_macro',
                                       scores['roc_auc_macro'], step)
                if not self.config.debug_mode:
                    self.writer.add_figure(
                        'valid_class',
                        utils.compare_predictions(y_true,
                                                  y_pred,
                                                  filepath=None), step)

            # Save model, with respect to validation loss
            if valid_loss < best_valid_loss:
                # print('best model: %4f' % valid_loss)
                best_step_valid_loss = drop_counter
                best_valid_loss = valid_loss
                torch.save(
                    self.model.state_dict(),
                    os.path.join(self.config.checkpoint_dir,
                                 'best_model_valid_loss.pth'))

            # Save model, with respect to validation roc_auc
            if scores['roc_auc_macro'] > best_roc_auc:
                best_step_valid_roc = drop_counter
                best_roc_auc = scores['roc_auc_macro']
                torch.save(
                    self.model.state_dict(),
                    os.path.join(self.config.checkpoint_dir,
                                 'best_model_valid_roc.pth'))

            # Save best model according to training loss
            if epoch_loss < best_train_epoch_loss:
                best_step_train_loss = drop_counter
                best_train_epoch_loss = epoch_loss
                torch.save(
                    self.model.state_dict(),
                    os.path.join(self.config.checkpoint_dir,
                                 'best_model_train.pth'))

        print("{} Training finished. -----------------------  Elapsed: {}".
              format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                     timedelta(seconds=time.time() - start_t)))
        print(
            "Best step (validation loss) = {} . ".format(best_step_valid_loss))
        print("Best step (validation roc_auc) = {} .".format(
            best_step_valid_roc))
        print("Best step (training loss) = {} .".format(best_step_train_loss))

        # Save last model
        torch.save(
            self.model.state_dict(),
            os.path.join(self.config.checkpoint_dir, 'best_model_final.pth'))