Example #1
def run_train(train_iter, valid_iter, model, tree_lstm, V):
    opt_struct = torch.optim.Adadelta(list(model.parameters()),
                                      lr=config["lr_struct"])
    opt_params = torch.optim.Adadelta(list(tree_lstm.parameters()),
                                      lr=config["lr_params"])

    model.train()
    tree_lstm.train()
    losses = []
    Dist = TreeCRF
    step = 0
    trees = None
    for epoch in range(100):
        print("Epoch", epoch)

        for i, ex in enumerate(train_iter):
            step += 1
            words, lengths = ex.word
            label = ex.label
            batch = label.shape[0]
            _, N = words.shape
            words = words.cuda()

            def tree_reward(spans, K):
                new_spans = expand_spans(spans, words, K, V)
                g, labels, indices, topo = TreeLSTM.spans_to_dgl(new_spans)
                ret = tree_lstm(g, labels, indices, topo,
                                torch.cat([lengths for _ in range(K)]))
                ret = ret.view(K, batch, -1)
                return -ret[:, torch.arange(batch), label].view(K, batch)

            sc = SelfCritical(tree_reward)
            phi = model(words, lengths)
            dist = Dist(phi)
            structs, rewards, score, max_score = sc.forward(dist,
                                                            K=config["RL_K"])

            if config["train_model"]:
                opt_params.zero_grad()
                score.mean().backward()
                clip(tree_lstm.parameters())
                opt_params.step()
                opt_params.zero_grad()

            if config["method"] == "reinforce":
                opt_struct.zero_grad()
                entropy = dist.entropy
                r = dist.log_prob(structs)
                obj = rewards.mul(r).mean(-1).mean(-1)
                policy = (obj - config["entropy"] *
                          entropy.div(lengths.float().cuda()).mean())
                policy.backward()
                clip(model.parameters())
                opt_struct.step()
            losses.append(-max_score.mean().detach())

            # DEBUG
            if i % 50 == 9:
                print(torch.tensor(losses).mean(), words.shape)
                print("Round")
                print("Entropy", entropy.mean().item())
                print("Reward", rewards.mean().item())
                if i % 1000 == 9:
                    valid_loss = valid_sup(valid_iter, model, tree_lstm, V)
                    fname = "/tmp/checkpoint.%s.%0d.%0d.%s" % (
                        NAME,
                        epoch,
                        i,
                        valid_loss,
                    )
                    torch.save((model, tree_lstm), fname)
                    wandb.save(fname)
                    trees = valid_show(valid_iter, model)
                else:
                    print(valid_loss)

                wandb.log({
                    "entropy": entropy.mean(),
                    "valid_loss": valid_loss,
                    "reward": rewards.mean(),
                    "step": step,
                    "tree": trees,
                    "reward_var": rewards.var(),
                    "loss": torch.tensor(losses).mean(),
                })
                losses = []
Example #2
    def update_parameters(self, obss, exps):
        self._update_number += 1
        logs = {}

        # ===== Calculate losses =====

        dist, value, scores = self.acmodel(obss,
                                           mask_future=True,
                                           attn_custom_mask=self.attn_mask)

        entropy = dist.entropy().mean()

        policy_loss = -(dist.log_prob(exps.action) * exps.advantage).mean()

        value_loss = (value - exps.returnn).pow(2).mean()

        loss = policy_loss - self.entropy_coef * entropy + self.value_loss_coef * value_loss

        # Update actor-critic

        self.optimizer.zero_grad()
        loss.backward()
        update_grad_norm = sum(
            p.grad.data.norm(2)**2 for p in self.acmodel.parameters())**0.5
        torch.nn.utils.clip_grad_norm_(self.acmodel.parameters(),
                                       self.max_grad_norm)
        self.optimizer.step()

        # Log some values

        # Save attention scores heatmap every 100 updates
        if self.wandb_dir is not None and self._update_number % 100 == 0:
            import os
            import wandb
            import seaborn as sns
            import matplotlib.pyplot as plt
            attn_fig = (sns.heatmap(scores[0].detach().numpy(),
                                    xticklabels=10,
                                    yticklabels=10).get_figure())
            img_name_base = str(
                os.path.join(self.wandb_dir,
                             f'attn_scores_{self._update_number:04}'))
            attn_fig.savefig(img_name_base, fmt='png')
            wandb.save(img_name_base + '*')
            plt.clf()

            # # For debugging
            # labels_fig = (sns.heatmap(self.seq_labels_debug[0].detach().numpy(), xticklabels=10, yticklabels=10)
            #               .get_figure())
            # labels_fig_base = str(os.path.join(self.wandb_dir,
            #                                    f'episode_labels_{self._update_number:04}'))
            # labels_fig.savefig(labels_fig_base, fmt='png')
            # plt.clf()

            mask_fig = (sns.heatmap(self.attn_mask[0].detach().numpy(),
                                    xticklabels=10,
                                    yticklabels=10).get_figure())
            mask_fig_base = str(
                os.path.join(self.wandb_dir, f'mask_{self._update_number:04}'))
            mask_fig.savefig(mask_fig_base, fmt='png')
            plt.clf()

        with torch.no_grad():
            # evaluate KL divergence b/w old and new policy
            # policy under newly updated model
            dist, _, _ = self.acmodel(obss)

            approx_kl = (exps.log_prob -
                         dist.log_prob(exps.action)).mean().item()
            adv_mean = exps.advantage.mean().item()
            adv_max = exps.advantage.max().item()
            adv_min = exps.advantage.min().item()
            adv_std = exps.advantage.std().item()

            # standard deviation of values
            value_std = value.std().item()

        logs.update({
            "entropy": entropy.item(),
            "value": value.mean().item(),
            "value_std": value_std,
            "policy_loss": policy_loss.item(),
            "value_loss": value_loss.item(),
            "grad_norm": update_grad_norm.item(),
            "adv_max": adv_max,
            "adv_min": adv_min,
            "adv_mean": adv_mean,
            "adv_std": adv_std,
            "kl": approx_kl,
        })
        return logs
Example #3
 def finalize(self, status: str) -> None:
     # upload all checkpoints from saving dir
     if self._log_model:
         wandb.save(os.path.join(self.save_dir, "*.ckpt"))
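For context, a minimal standalone sketch of the same pattern (not from the source repository): wandb.save accepts a glob string, so every matching checkpoint in a directory is registered for upload to the active run. The "demo" project name and save_dir value are placeholders.

import os
import wandb

run = wandb.init(project="demo")
save_dir = "./checkpoints"  # placeholder directory assumed to contain *.ckpt files
# Register every checkpoint matching the glob for upload to the W&B run.
wandb.save(os.path.join(save_dir, "*.ckpt"))
run.finish()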
Example #4
        opt.zero_grad()

        log = {}

        if i % 10 == 0:
            print(epoch, i, f'loss - {loss.item()}')

            log = {**log, 'epoch': epoch, 'iter': i, 'loss': loss.item()}

        if i % 100 == 0:
            sample_text = text[:1]
            token_list = sample_text.masked_select(sample_text != 0).tolist()
            decoded_text = tokenizer.decode(token_list)

            image = dalle.generate_images(
                text[:1],
                mask=mask[:1],
                filter_thres=0.9  # topk sampling at 0.9
            )

            save_model(f'./dalle.pt')
            wandb.save(f'./dalle.pt')

            log = {**log, 'image': wandb.Image(image, caption=decoded_text)}

        wandb.log(log)

save_model(f'./dalle-final.pt')
wandb.save('./dalle-final.pt')
wandb.finish()
Example #5
def main():

    opt = parse_option()
    wandb.init(project=opt.model_path.split("/")[-1], tags=opt.tags)
    wandb.config.update(opt)
    wandb.save('*.py')
    wandb.run.save()

    train_loader, val_loader, meta_testloader, meta_valloader, n_cls, no_sample = get_dataloaders(
        opt)
    # model
    model = create_model(opt.model,
                         n_cls,
                         opt.dataset,
                         n_trans=opt.trans,
                         embd_sz=opt.memfeature_size)
    wandb.watch(model)

    # optimizer
    if opt.adam:
        print("Adam")
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=opt.learning_rate,
                                     weight_decay=0.0005)
    else:
        print("SGD")
        optimizer = optim.SGD(model.parameters(),
                              lr=opt.learning_rate,
                              momentum=opt.momentum,
                              weight_decay=opt.weight_decay)

    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate**3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)

    MemBank = np.random.randn(no_sample, opt.memfeature_size)
    MemBank = torch.tensor(MemBank, dtype=torch.float).cuda()
    MemBankNorm = torch.norm(MemBank, dim=1, keepdim=True)
    MemBank = MemBank / (MemBankNorm + 1e-6)

    # routine: supervised pre-training
    for epoch in range(1, opt.epochs + 1):
        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss, MemBank = train(epoch, train_loader, model,
                                               criterion, optimizer, opt,
                                               MemBank)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        val_acc, val_acc_top5, val_loss = 0, 0, 0  #validate(val_loader, model, criterion, opt)

        #validate
        start = time.time()
        meta_val_acc, meta_val_std = 0, 0  #meta_test(model, meta_valloader)
        test_time = time.time() - start
        print(
            'Meta Val Acc : {:.4f}, Meta Val std: {:.4f}, Time: {:.1f}'.format(
                meta_val_acc, meta_val_std, test_time))

        #evaluate
        start = time.time()
        meta_test_acc, meta_test_std = 0, 0  #meta_test(model, meta_testloader)
        test_time = time.time() - start
        print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}, Time: {:.1f}'.
              format(meta_test_acc, meta_test_std, test_time))

        # regular saving
        if epoch % opt.save_freq == 0 or epoch == opt.epochs:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'optimizer': optimizer.state_dict(),
                'model': model.state_dict(),
            }
            save_file = os.path.join(opt.save_folder,
                                     'model_' + str(wandb.run.name) + '.pth')
            torch.save(state, save_file)

            #wandb saving
            torch.save(state, os.path.join(wandb.run.dir, "model.pth"))

        wandb.log({
            'epoch': epoch,
            'Train Acc': train_acc,
            'Train Loss': train_loss,
            'Val Acc': val_acc,
            'Val Loss': val_loss,
            'Meta Test Acc': meta_test_acc,
            'Meta Test std': meta_test_std,
            'Meta Val Acc': meta_val_acc,
            'Meta Val std': meta_val_std
        })

    #final report
    print("GENERATING FINAL REPORT")
    generate_final_report(model, opt, wandb)

    #remove output.txt log file
    output_log_file = os.path.join(wandb.run.dir, "output.log")
    if os.path.isfile(output_log_file):
        os.remove(output_log_file)
    else:  ## Show an error ##
        print("Error: %s file not found" % output_log_file)
Example #6
File: train.py  Project: dedbox/TOAD-GAN
def train(real, opt):
    """ Wrapper function for training. Calculates necessary scales then calls train_single_scale on each. """
    generators = []
    noise_maps = []
    noise_amplitudes = []

    if opt.game == 'mario':
        token_group = MARIO_TOKEN_GROUPS
    else:  # if opt.game == 'mariokart':
        token_group = MARIOKART_TOKEN_GROUPS

    scales = [[x, x] for x in opt.scales]
    opt.num_scales = len(scales)

    if opt.game == 'mario':
        scaled_list = special_mario_downsampling(opt.num_scales, scales, real,
                                                 opt.token_list)
    else:  # if opt.game == 'mariokart':
        scaled_list = special_mariokart_downsampling(opt.num_scales, scales,
                                                     real, opt.token_list)

    reals = [*scaled_list, real]

    # If (experimental) token grouping feature is used:
    if opt.token_insert >= 0:
        reals = [(token_to_group(r, opt.token_list, token_group)
                  if i < opt.token_insert else r) for i, r in enumerate(reals)]
        reals.insert(
            opt.token_insert,
            token_to_group(reals[opt.token_insert], opt.token_list,
                           token_group))
    input_from_prev_scale = torch.zeros_like(reals[0])

    stop_scale = len(reals)
    opt.stop_scale = stop_scale

    # Log the original input level as an image
    img = opt.ImgGen.render(one_hot_to_ascii_level(real, opt.token_list))
    wandb.log({"real": wandb.Image(img)}, commit=False)
    os.makedirs("%s/state_dicts" % (opt.out_), exist_ok=True)

    # Training Loop
    divergences = []
    for current_scale in range(0, stop_scale):
        opt.outf = "%s/%d" % (opt.out_, current_scale)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass

        # If we are seeding, we need to adjust the number of channels
        if current_scale < (opt.token_insert + 1):  # (stop_scale - 1):
            opt.nc_current = len(token_group)

        # Initialize models
        D, G = init_models(opt)
        # If we are seeding, the weights after the seed need to be adjusted
        if current_scale == (opt.token_insert + 1):  # (stop_scale - 1):
            D, G = restore_weights(D, G, current_scale, opt)

        # Actually train the current scale
        z_opt, input_from_prev_scale, G, divs = train_single_scale(
            D, G, reals, generators, noise_maps, input_from_prev_scale,
            noise_amplitudes, opt)

        # Reset grads and save current scale
        G = reset_grads(G, False)
        G.eval()
        D = reset_grads(D, False)
        D.eval()

        generators.append(G)
        noise_maps.append(z_opt)
        noise_amplitudes.append(opt.noise_amp)
        divergences.append(divs)

        torch.save(noise_maps, "%s/noise_maps.pth" % (opt.out_))
        torch.save(generators, "%s/generators.pth" % (opt.out_))
        torch.save(reals, "%s/reals.pth" % (opt.out_))
        torch.save(noise_amplitudes, "%s/noise_amplitudes.pth" % (opt.out_))
        torch.save(opt.num_layer, "%s/num_layer.pth" % (opt.out_))
        torch.save(opt.token_list, "%s/token_list.pth" % (opt.out_))
        wandb.save("%s/*.pth" % opt.out_)

        torch.save(G.state_dict(),
                   "%s/state_dicts/G_%d.pth" % (opt.out_, current_scale))
        wandb.save("%s/state_dicts/*.pth" % opt.out_)

        del D, G

    torch.save(torch.tensor(divergences), "%s/divergences.pth" % opt.out_)

    return generators, noise_maps, reals, noise_amplitudes
Example #7
    def __init__(
        self,
        monitor='val_loss',
        verbose=0,
        mode='auto',
        save_weights_only=False,
        log_weights=False,
        log_gradients=False,
        save_model=True,
        training_data=None,
        validation_data=None,
        labels=[],
        data_type=None,
        predictions=36,
        generator=None,
        input_type=None,
        output_type=None,
        log_evaluation=False,
        validation_steps=None,
        class_colors=None,
    ):
        """Constructor.

        # Arguments
            monitor: quantity to monitor.
            mode: one of {auto, min, max}.
                'min' - save model when monitor is minimized
                'max' - save model when monitor is maximized
                'auto' - try to guess when to save the model
            save_weights_only: if True, then only the model's weights will be
                saved (`model.save_weights(filepath)`), else the full model
                is saved (`model.save(filepath)`).
            save_model:
                True - save a model when monitor beats all previous epochs
                False - don't save models
            log_weights: if True save the weights in wandb.history
            log_gradients: if True log the training gradients in wandb.history
            training_data: tuple (X,y) needed for calculating gradients
            labels: list of labels to convert numeric output to if you are building a
                multiclass classifier.  If you are making a binary classifier you can pass in
                a list of two labels ["label for false", "label for true"]
            predictions: the number of predictions to make each epoch if data_type is set, max is 100.
            generator: a generator to use for making predictions
            input_type: the type of the model input. can be one of:
                (label, image, segmentation_mask).
            output_type: the type of the model output. can be one of:
                (label, image, segmentation_mask).
            log_evaluation: if True save a dataframe containing the full
                validation results at the end of training.
            validation_steps: if `validation_data` is a generator, how many
                steps to run the generator for the full validation set.
            class_colors: if the input or output is a segmentation mask, an array
                containing an rgb tuple (range 0.-1.) for each class.
        """
        if wandb.run is None:
            raise wandb.Error(
                'You must call wandb.init() before WandbCallback()')

        self.validation_data = None
        # This is kept around for legacy reasons
        if validation_data is not None:
            if is_generator_like(validation_data):
                generator = validation_data
            else:
                self.validation_data = validation_data

        self.labels = labels
        self.predictions = min(predictions, 100)

        self.monitor = monitor
        self.verbose = verbose
        self.save_weights_only = save_weights_only

        wandb.save('model-best.h5')
        self.filepath = os.path.join(wandb.run.dir, 'model-best.h5')
        self.save_model = save_model
        self.log_weights = log_weights
        self.log_gradients = log_gradients
        self.training_data = training_data
        self.generator = generator
        self._graph_rendered = False

        self.input_type = input_type or data_type
        self.output_type = output_type
        self.log_evaluation = log_evaluation
        self.validation_steps = validation_steps
        self.class_colors = np.array(
            class_colors) if class_colors is not None else None

        if self.training_data:
            if len(self.training_data) != 2:
                raise ValueError("training data must be a tuple of length two")

        # From Keras
        if mode not in ['auto', 'min', 'max']:
            print('WandbCallback mode %s is unknown, '
                  'fallback to auto mode.' % (mode))
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = operator.lt
            self.best = float('inf')
        elif mode == 'max':
            self.monitor_op = operator.gt
            self.best = float('-inf')
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = operator.gt
                self.best = float('-inf')
            else:
                self.monitor_op = operator.lt
                self.best = float('inf')
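For reference, a minimal usage sketch that is not part of the source: the constructor above requires an active run, so wandb.init() must come first; build_model, x_train/y_train, and x_val/y_val are hypothetical placeholders for a compiled Keras model and its data.

import wandb

wandb.init(project="demo")  # must precede WandbCallback(), which raises wandb.Error otherwise

model = build_model()  # hypothetical helper returning a compiled Keras model
model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=10,
    # Monitor val_loss and keep the best weights as model-best.h5 in wandb.run.dir.
    callbacks=[WandbCallback(monitor="val_loss", mode="min", save_model=True)],
)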
Example #8
        })

        history["loss-mu"][epoch] = np.mean(epoch_loss)
        history["loss-std"][epoch] = np.std(epoch_loss)
        history["td-loss-mu"][epoch] = np.mean(epoch_td_loss)
        history["td-loss-std"][epoch] = np.std(epoch_td_loss)
        history["ntd-loss-mu"][epoch] = np.mean(epoch_ntd_loss)
        history["ntd-loss-std"][epoch] = np.std(epoch_ntd_loss)
        history["margin-loss-mu"][epoch] = np.mean(epoch_margin_loss)
        history["margin-loss-std"][epoch] = np.std(epoch_margin_loss)
        history["l2-loss-mu"][epoch] = np.mean(epoch_l2_loss)
        history["l2-loss-std"][epoch] = np.std(epoch_l2_loss)

    # Save finished model.
    torch.save(online_model, './dqfd_{}.pkl'.format(args.level))
    wandb.save('./dqfd_{}.pkl'.format(args.level))

    env = retro.make('SonicTheHedgehog-Genesis', state=args.level)
    scores = utils.play_evaluation_games(
        env,
        online_model,
        state_transformer=utils.torchify_state,
        action_transformer=lambda x: utils.decoding_action_transformer(
            x, decoding),
        n_games=20,
        rnd_steps=50,
        max_frames=1000)
    wandb.log({
        'eval_mean': np.mean(scores),
        'eval_std': np.std(scores),
        'eval_min': np.min(scores),
Example #9
def train(args, trainer, task, epoch_itr, force_refine_step=None):
    """Train the model for one epoch."""

    # Update parameters every N batches
    def is_better(a, b):
        return a > b if args.maximize_best_checkpoint_metric else a < b

    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    if hasattr(args, "progressive") and args.progressive:
        task.dataset("train").set_random_refine_step(
            args.refinetot, force_refine_step=force_refine_step)
    last_samples = None
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        if samples is None or len(samples) == 0:
            sys.stderr.write("Empty sample detected\n")
            sys.stderr.flush()
            samples = last_samples
        else:
            last_samples = samples
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue
        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args,
                                    trainer,
                                    task,
                                    epoch_itr,
                                    valid_subsets,
                                    force_refine_step=force_refine_step)
            # if distributed_utils.is_master(args):
            #     print("saving:", trainer.get_num_updates())
            #     nsml.save(str(trainer.get_num_updates()))
            if not hasattr(checkpoint_utils.save_checkpoint,
                           'best') or is_better(
                               valid_losses[0],
                               checkpoint_utils.save_checkpoint.best):
                if distributed_utils.is_master(args):
                    print("saving checkpoint ...")
                    sys.stdout.flush()
                    if HAS_NSML:
                        nsml.save("best")
                    else:
                        torch.save({"model": trainer.get_model().state_dict()},
                                   "/tmp/best.pt")
                    if HAS_WANDB:
                        wandb.save("/tmp/best.pt")
                    sys.stdout.flush()
                checkpoint_utils.save_checkpoint.best = valid_losses[0]

        if args.decoder_wise_training and update_num_to_refine_step(
                num_updates) != force_refine_step:
            if HAS_NSML:
                nsml.load("best")
            else:
                # Retrieve the model
                if distributed_utils.is_master(args):
                    state = torch.load("/tmp/best.pt", map_location="cpu")
                    trainer.model.load_state_dict(state["model"])
                # Sync
                assert isinstance(trainer.model,
                                  parallel.DistributedDataParallel)
                if isinstance(trainer.model, parallel.DistributedDataParallel):
                    trainer.model._sync_params()

            checkpoint_utils.save_checkpoint.best = 0.
            force_refine_step = update_num_to_refine_step(num_updates)
            trainer.criterion.pool.clear()
            print("| Start refinement step:", force_refine_step)

        if num_updates >= max_update:
            break

        if hasattr(args, "progressive") and args.progressive:
            task.dataset("train").set_random_refine_step(
                args.refinetot, force_refine_step=force_refine_step)

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
Example #10
                    'epoch': epoch,
                    'iterations': (epoch + 1) * len(train_loader),
                    'best_loss': best_loss,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'AMPScaler': scaler.state_dict()
                }
                torch.save(checkpoint,
                           os.path.join(save_path, trainID + ".pth.tar"))
                if args.modelid != 9:
                    torch.onnx.export(model,
                                      images,
                                      trainID + ".onnx",
                                      input_names=["LRCurrTP"],
                                      output_names=["SuperResolvedCurrTP"])
                    wandb.save(trainID + ".onnx")

            tb_writer.add_scalar('Train/EpochLoss', loss_reducer(train_loss),
                                 epoch)
            wandb.log({"TrainEpochLoss":
                       loss_reducer(train_loss)})  #, step=epoch)

            #Validate
            if val_loader:
                model.eval()
                with torch.no_grad():
                    runningLoss = []
                    val_loss = []
                    runningAcc = []
                    val_acc = []
                    print('Epoch ' + str(epoch) + ': Val')
Example #11
 def save_model(self, model: Tree2Seq, output_name: str,
                configuration: Dict, **kwargs: Dict) -> str:
     checkpoint_path = super().save_model(model, output_name, configuration,
                                          **kwargs)
     wandb.save(checkpoint_path)
     return checkpoint_path
Example #12
    def train(self):
        wandb.init(name=cfg.EXP_NAME, project='AttnGAN', config=cfg, dir='../logs')

        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        text_encoder, image_encoder, netG, netsD, optimizerG, optimizersD =  \
            self.apply_apex(text_encoder, image_encoder, netG, netsD, optimizerG, optimizersD)
        # add watch
        wandb.watch(netG)
        for D in netsD:
            wandb.watch(D)

        avg_param_G = copy_G_params(netG)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        log_dict = {}
        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                data = data_iter.next()
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels, fake_labels)
                    # backward and update parameters
                    if cfg.APEX:
                        from apex import amp
                        with amp.scale_loss(errD, optimizersD[i], loss_id=i) as errD_scaled:
                            errD_scaled.backward()
                    else:
                        errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())
                    log_name = 'errD_{}'.format(i)
                    log_dict[log_name] = errD.item()

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs, G_log_dict = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                log_dict.update(G_log_dict)
                log_dict['kl_loss'] = kl_loss.item()
                # backward and update parameters
                if cfg.APEX:
                    from apex import amp
                    with amp.scale_loss(errG_total, optimizerG, loss_id=len(netsD)) as errG_scaled:
                        errG_scaled.backward()
                else:
                    errG_total.backward()
                optimizerG.step()
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                wandb.log(log_dict)
                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                    wandb.save('logs.ckpt')
                # save images
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb,
                                          words_embs, mask, image_encoder,
                                          captions, cap_lens, epoch, name='average')
                    load_params(netG, backup_para)
                    #
                    # self.save_img_results(netG, fixed_noise, sent_emb,
                    #                       words_embs, mask, image_encoder,
                    #                       captions, cap_lens,
                    #                       epoch, name='current')
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.item(), errG_total.item(),
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
Example #13
File: run.py  Project: parkerhuynh/kalapa
def main(args):
    wandb.init(project="kalapa")
    seed = args.seed
    os.system(f"git commit -am \"{args.message}\"")
    code_version = os.popen('git rev-parse HEAD').read().strip()
    wandb.log({
        "user": "******",
        "seed": seed,
        "code_version": code_version,
        "data_version": args.data_version,
        "weight_version": args.weight_version
    })
    train = pd.read_csv(f"../../data/kalapa/{args.data_version}/train.csv")
    #new_train = pd.read_csv(f"../../data/kalapa/{args.data_version}/new_train.csv")
    test = pd.read_csv(f"../../data/kalapa/{args.data_version}/test.csv")
    cols = train.iloc[:, 2:].columns

    def to_category(df_fe):
        for col in cols:
            if df_fe[col].dtype.name == "object":
                df_fe[col] = df_fe[col].astype('category')
        return df_fe

    train = to_category(train)
    test = to_category(test)

    lgbm_param = {'boosting_type': 'gbdt', \
                  'colsample_bytree': 0.6602479798930369, \
                  'is_unbalance': True, \
                  'learning_rate': 0.01, \
                  'max_depth': 15, \
                  'metric': 'auc', \
                  'min_child_samples': 25, \
                  'num_leaves': 60, \
                  'objective': 'binary', \
                  'reg_alpha': 0.4693391197064131, \
                  'reg_lambda': 0.16175478669541327, \
                  'subsample_for_bin': 60000}
    NUM_BOOST_ROUND = 10000

    def kfold(train_fe, test_fe):
        #nonlocal col2
        y_label = train_fe.label
        seeds = np.random.randint(0, 10000, 1)
        preds = 0
        feature_important = True
        avg_train_gini = 0
        avg_val_gini = 0

        for s in seeds:
            skf = StratifiedKFold(n_splits=5, random_state=6484, shuffle=True)
            lgbm_param['random_state'] = 6484
            seed_train_gini = 0
            seed_val_gini = 0
            for i, (train_idx, val_idx) in enumerate(
                    skf.split(np.zeros(len(y_label)), y_label)):
                X_train, X_val = train_fe.iloc[train_idx].drop(
                    ["id", "label"],
                    1), train_fe.iloc[val_idx].drop(["id", "label"], 1)
                #new_X_train = new_train_fe.drop(["id", "label"], 1)
                #X_train = pd.concat([X_train,new_X_train], axis = 0)

                #X_train = to_category(X_train)
                #for col in col2:
                #X_train[col] = X_train[col].astype('category')

                y_train, y_val = y_label.iloc[train_idx], y_label.iloc[val_idx]
                #y_train = pd.concat([y_train, new_train_fe.label], axis = 0)

                lgb_train = lgb.Dataset(X_train, y_train)
                lgb_eval = lgb.Dataset(X_val, y_val)

                evals_result = {}
                model = lgb.train(lgbm_param,
                                  lgb_train,
                                  num_boost_round=NUM_BOOST_ROUND,
                                  early_stopping_rounds=400,
                                  feval=lgb_gini,
                                  verbose_eval=200,
                                  evals_result=evals_result,
                                  valid_sets=[lgb_train, lgb_eval])

                seed_train_gini += model.best_score["training"][
                    "gini"] / skf.n_splits
                seed_val_gini += model.best_score["valid_1"][
                    "gini"] / skf.n_splits

                avg_train_gini += model.best_score["training"]["gini"] / (
                    len(seeds) * skf.n_splits)
                avg_val_gini += model.best_score["valid_1"]["gini"] / (
                    len(seeds) * skf.n_splits)
                if feature_important is None:
                    feature_important = model.feature_importance() / (
                        len(seeds) * skf.n_splits)
                else:
                    feature_important += model.feature_importance() / (
                        len(seeds) * skf.n_splits)

                pred = model.predict(test_fe.drop(["id"], 1))
                preds += pred / (skf.n_splits * len(seeds))

                print("Fold {}: {}/{}".format(
                    i, model.best_score["training"]["gini"],
                    model.best_score["valid_1"]["gini"]))
                log = {
                    "gini_train": model.best_score["training"]["gini"],
                    "gini": model.best_score["valid_1"]["gini"],
                    "epoch": NUM_BOOST_ROUND
                }
                wandb.log(log)
            print("Seed {}: {}/{}".format(s, seed_train_gini, seed_val_gini))

        print("-" * 30)
        print("Avg train gini: {}".format(avg_train_gini))
        print("Avg valid gini: {}".format(avg_val_gini))
        wandb.log({"gini": avg_val_gini})
        print("=" * 30)
        return preds

    preds = kfold(train, test)
    test["label"] = preds
    test[["id", "label"]].to_csv("test_preds.csv", index=False)
    wandb.save("test_preds.csv")
Example #14
def train(start_epoch, num_epochs):
    early_stopping = EarlyStopping(opt, verbose=True)
    for epoch in tqdm(range(start_epoch, num_epochs)):
        torch.cuda.empty_cache()
        print("Epoch: %d" % epoch)
        train_losses = []
        model.train()
        past = time.time()
        for batch_idx, (inputs, targets) in enumerate(train_data_loader):
            # models.zero_grad()
            # optimizer.zero_grad()  # equivalent to models.zero_grad() when optimizer = optim.Optimizer(models.parameters())
            # Zero the gradients
            optimizer.zero_grad()

            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

            # if (batch_idx + 1) % num == 0:
            # 	print(batch_idx + 1, len(dataloader), 'Loss: %.3f' % (train_loss / num))
            # 	train_loss = 0
        now = time.time()
        train_loss = np.mean(np.array(train_losses))
        print(epoch, "loss:%.3f,time:%.2fs" % (train_loss, now - past))
        writer.add_scalar("train_loss", train_loss, epoch)
        wandb.log({"train_loss": train_loss}, step=epoch)
        train_loss = 0
        # checkpoint = {
        #     "model_state_dict": models.module.state_dict(),
        #     "opt_state_dict": optimizer.state_dict(),
        #     "epoch": epoch,
        # }

        scheduler.step()
        # the end of one epoch
        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch_idx, (inputs,
                            targets) in enumerate(validation_data_loader):
                inputs, targets = inputs.to(device), targets.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_losses.append(loss.item())
        val_loss = np.mean(np.array(val_losses))
        wandb.log({"val_loss": val_loss}, step=epoch)

        #####some testing#####
        print("xxxxxxx".format(xxxxxxx))
        #####some logging#####
        prefix = opt.path_to_checkpoint + opt.hidden_size + "_"
        file = prefix + "xxx_xxx_xxx.pt"
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            opt.path_to_checkpoint + "%d_%d_%d_%d_%d_model.onnx" % (
                opt.value_dim,
                opt.moment_dim,
                opt.hidden_size,
                opt.label_dim,
                opt.LSTM_num_layers,
            ),
        )
        wandb.save("mymodel.h5")
        early_stopping(val_loss, model, opt)
        if early_stopping.early_stop:
            print("Early stopping")
            break
Example #15
                break
        if args.kle_rollback:
            if (b_logprobs[minibatch_ind] - agent.get_action(
                    b_obs[minibatch_ind],
                    b_actions.long()[minibatch_ind].T,
                    b_invalid_action_masks[minibatch_ind],
                    envs)[1]).mean() > args.target_kl:
                agent.load_state_dict(target_agent.state_dict())
                break

    ## CRASH AND RESUME LOGIC:
    if args.prod_mode:
        if not os.path.exists(f"models/{experiment_name}"):
            os.makedirs(f"models/{experiment_name}")
        torch.save(agent.state_dict(), f"{wandb.run.dir}/agent.pt")
        wandb.save(f"agent.pt")

    # TRY NOT TO MODIFY: record rewards for plotting purposes
    writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar("charts/update", update, global_step)
    writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
    writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
    writer.add_scalar("losses/entropy", entropy.mean().item(), global_step)
    writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
    if args.kle_stop or args.kle_rollback:
        writer.add_scalar("debug/pg_stop_iter", i_epoch_pi, global_step)
    print("SPS:", int(global_step / (time.time() - start_time)))

envs.close()
writer.close()
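The crash-and-resume block above only covers the save side. A hedged sketch of a matching resume path, which is an assumption and not shown in the source, would fetch the file back with wandb.restore and reload the state dict; "demo" and `agent` are placeholders for the project name and the model built elsewhere in the script.

import torch
import wandb

wandb.init(project="demo", resume="allow")  # resume policy here is an assumption
# wandb.restore downloads the named file from the run into wandb.run.dir and
# returns a readable file handle, or None if the run has no such file.
ckpt = wandb.restore("agent.pt")
if ckpt is not None:
    agent.load_state_dict(torch.load(ckpt.name))  # `agent` is the model defined elsewhere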
Example #16
def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_init_hvd(args)

    # Print args
    print(args)

    # if not HAS_NSML:
    #     args.data[0] = args.data[0].replace("/train", "")

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)

    if args.train_decoder_only:
        for name, param in model.named_parameters():
            if "decoder" not in name:
                param.requires_grad_(False)

    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Setup session
    if HAS_WANDB and distributed_utils.is_master(args):
        wandb.init(project="cmlm", config=args)
        wandb.watch(model)

    # Load pre-trained model
    data_token = args.data[0].split("/")[-1]
    if "bert" in args.arch:
        pretrained_path = "{}/train/pretrained_models/maskPredict_{}/checkpoint_best.pt".format(
            DATASET_PATH,
            data_token.split(".")[-1].replace("-", "_"))
        if not HAS_NSML:
            pretrained_path = pretrained_path.replace("/train", "")
        print("| loading", pretrained_path)
        state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_path)
        model.load_state_dict(state["model"], strict=True)
        baseline_model = task.build_model(args)
        baseline_model.load_state_dict(state["model"], strict=True)
        if torch.cuda.is_available():
            baseline_model.cuda()
        task.set_baseline_model(baseline_model)

    if not args.masking and HAS_NSML:

        def nsml_bind(model):
            def save(dir_path):
                state = {
                    'model': model.state_dict(),
                }
                torch.save(state, os.path.join(dir_path, 'best.pt'))

            def load(dir_path):
                state = torch.load(os.path.join(dir_path, 'best.pt'),
                                   map_location="cpu")
                model.load_state_dict(state['model'], strict=False)
                model.cuda()
                print('model loaded!')

            nsml.bind(save=save, load=load)

        nsml_bind(model)

    if args.load:
        print("loading model from session", args.load)
        if args.load.startswith("nsml://"):
            session = args.load.replace("nsml://", "")
        if ".pt" in session:
            session = session.replace(".pt", "")
            session, checkpoint_name = session.rsplit("/", 1)
        else:
            checkpoint_name = "best"
        if "-" in checkpoint_name:
            start, end = checkpoint_name.replace("epoch", "").split("-")
            checkpoints = [
                "epoch{}".format(i) for i in range(int(start),
                                                   int(end) + 1)
            ]
            print("| checkpoint average:", checkpoints)
            state_dict = None

            def load(dir_path):
                nonlocal state_dict, checkpoints
                state = torch.load(os.path.join(dir_path, 'best.pt'))
                model_state = state["model"]
                for k in model_state:
                    model_state[k] = model_state[k] / float(len(checkpoints))
                if state_dict is None:
                    state_dict = model_state
                else:
                    for k in state_dict:
                        state_dict[k] += model_state[k]
                print("checkpoint loaded")

            for checkpoint_name in checkpoints:
                nsml.load(checkpoint_name, load_fn=load, session=session)
            model.load_state_dict(state_dict)
        else:

            def load(dir_path):
                state = torch.load(os.path.join(dir_path, 'best.pt'))
                state_dict = state["model"]
                model.load_state_dict(state_dict)
                print("loaded")

            nsml.load(checkpoint_name, load_fn=load, session=session)

    # Prepare for decoder wise training
    if args.decoder_wise_training:
        print("| Decoder wise training, start refinement step 0")
        progressive_training_step = 0
        assert args.ddp_backend == "c10d"
    else:
        progressive_training_step = None

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    if hasattr(args, "progressive") and args.progressive:
        for i in range(args.refinetot if not getattr(args, "pnet", False) else
                       args.refinetot - 1):
            print("validating for refine step", i)
            validate(args,
                     trainer,
                     task,
                     epoch_itr,
                     valid_subsets,
                     force_refine_step=i)
        print("---")
    validate(args, trainer, task, epoch_itr, valid_subsets)
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates(
    ) < max_update:
        # train for one epoch
        train(args,
              trainer,
              task,
              epoch_itr,
              force_refine_step=progressive_training_step)
        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(
                args,
                trainer,
                task,
                epoch_itr,
                valid_subsets,
                force_refine_step=progressive_training_step)
        else:
            valid_losses = [None]

        if args.decoder_wise_training:
            progressive_training_step = update_num_to_refine_step(
                trainer.get_num_updates())

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            if HAS_NSML:
                if distributed_utils.is_master(args):
                    print("nsml save for epoch", epoch_itr.epoch)
                    nsml.save("epoch{}".format(epoch_itr.epoch))
            else:
                torch.save({"model": trainer.get_model().state_dict()},
                           "/tmp/epoch{}.pt".format(epoch_itr.epoch))
                if HAS_WANDB:
                    wandb.save("/tmp/epoch{}.pt".format(epoch_itr.epoch))
                # checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if ':' in getattr(args, 'data', ''):
            # sharded data: get train iterator for next epoch
            epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
def run(arguments, tag_in_vcs=False) -> None:
    azure_info_path = arguments.get('--azure-info', None)
    testrun = arguments.get('--testrun')
    max_files_per_dir = arguments.get('--max-files-per-dir')

    dir_path = Path(__file__).parent.absolute()

    # if you do not pass arguments for train/valid/test data default to files checked into repo.
    if not arguments['TRAIN_DATA_PATH']:
        arguments['TRAIN_DATA_PATH'] = str(dir_path / 'data_dirs_train.txt')
        arguments['VALID_DATA_PATH'] = str(dir_path / 'data_dirs_valid.txt')
        arguments['TEST_DATA_PATH'] = str(dir_path / 'data_dirs_test.txt')

    train_data_dirs = test.expand_data_path(arguments['TRAIN_DATA_PATH'],
                                            azure_info_path)
    valid_data_dirs = test.expand_data_path(arguments['VALID_DATA_PATH'],
                                            azure_info_path)
    test_data_dirs = test.expand_data_path(arguments['TEST_DATA_PATH'],
                                           azure_info_path)

    # default model save location
    if not arguments['SAVE_FOLDER']:
        arguments['SAVE_FOLDER'] = str(dir_path.parent /
                                       'resources/saved_models/')

    save_folder = arguments['SAVE_FOLDER']

    model_class = model_restore_helper.get_model_class_from_name(
        arguments['--model'])

    hyperparameters = model_class.get_default_hyperparameters()
    run_name = make_run_id(arguments)

    # make name of wandb run = run_id (Doesn't populate yet)
    hyperparameters['max_epochs'] = int(arguments.get('--max-num-epochs'))

    if testrun:
        hyperparameters['max_epochs'] = 2
        if not max_files_per_dir:
            max_files_per_dir = 1

    # override hyperparams if flag is passed
    hypers_override = arguments.get('--hypers-override')
    if hypers_override is not None:
        hyperparameters.update(json.loads(hypers_override))
    elif arguments.get('--hypers-override-file') is not None:
        with open(arguments.get('--hypers-override-file')) as f:
            hyperparameters.update(json.load(f))

    os.makedirs(save_folder, exist_ok=True)

    if tag_in_vcs:
        hyperparameters['git_commit'] = git_tag_run(run_name)

    # turns off wandb if you don't want to log anything
    if arguments.get('--dryrun'):
        os.environ["WANDB_MODE"] = 'dryrun'
    # save hyperparams to logging
    # must filter out type=set from logging when as that is not json serializable
    wandb.init(name=run_name,
               config={
                   k: v
                   for k, v in hyperparameters.items()
                   if not isinstance(v, set)
               })
    wandb.config.update({
        'model-class':
        arguments['--model'],
        'train_folder':
        str(train_data_dirs),
        'valid_folder':
        str(valid_data_dirs),
        'save_folder':
        str(save_folder),
        'test_folder':
        str(test_data_dirs),
        'CUDA_VISIBLE_DEVICES':
        os.environ.get("CUDA_VISIBLE_DEVICES", 'Not Set'),
        'run-name':
        arguments.get('--run-name'),
        'CLI-command':
        ' '.join(sys.argv)
    })

    if arguments.get('--evaluate-model'):
        model_path = RichPath.create(arguments['--evaluate-model'])
    else:
        model_path = run_train(model_class,
                               train_data_dirs,
                               valid_data_dirs,
                               save_folder,
                               hyperparameters,
                               azure_info_path,
                               run_name,
                               arguments['--quiet'],
                               max_files_per_dir=max_files_per_dir,
                               parallelize=not (arguments['--sequential']))

    wandb.config['best_model_path'] = str(model_path)
    wandb.save(str(model_path.to_local_path()))

    # only limit files in test run if `--testrun` flag is passed by user.
    if testrun:
        compute_evaluation_metrics(model_path, arguments, azure_info_path,
                                   valid_data_dirs, test_data_dirs,
                                   max_files_per_dir)
    else:
        compute_evaluation_metrics(model_path, arguments, azure_info_path,
                                   valid_data_dirs, test_data_dirs)
def run_experiment(dataset, network, model, proj_name, epoch, train_args):

    print(
        f"Running experiment with network '{network}' and dataset '{dataset}''"
    )
    datasets_module = importlib.import_module(
        "lab1.language_model.datasets.house_pred")
    dataset_class_ = getattr(datasets_module, dataset)
    # dataset_args = experiment_config.get("dataset_args", {})

    models_module = importlib.import_module("lab1.language_model.models.base2")
    model_class_ = getattr(models_module, model)

    networks_module = importlib.import_module(
        "lab1.language_model.networks.mlp")
    network_fn = getattr(networks_module, network)
    # save_net_artifact(project_name=proj_name, network=network_fn())

    # network_args = experiment_config.get("network_args", {})

    # mlflow.set_tracking_uri("sqlite:///mlruns.db")
    model = model_class_(dataset_cls=dataset_class_, network_fn=network_fn)
    # input_schema = Schema([TensorSpec(type=np.dtype(np.float32), shape=(-1, 13), name="house_attribs")])
    # output_schema = Schema([TensorSpec(type=np.dtype(np.float32), shape=(-1, 1), name="predicted house price")])
    # signature = ModelSignature(inputs=input_schema, outputs=output_schema)
    # input_example = np.array([[1., 2.5, 3. , 1.7, 2.1, 1.3, .5, .75, .89, 1.9, 2.15, 2.2, .6]])
    # mlflow.pyfunc.save_model(path="my_model", python_model=model, signature=signature, input_example=input_example )

    config = dict(dataset=dataset,
                  network=network,
                  model=model,
                  epoch=epoch,
                  train_args=train_args)

    net_config = dict(input_shape=(13, ),
                      output_shape=(1,),
                      layer_size=64,
                      dropout_amount=0.2,
                      num_layers=3)

    with wandb.init(project=proj_name, config=config) as run:
        config = wandb.config

        # Add model artifact
        model_artifact = wandb.Artifact("convnet",
                                        type="model",
                                        description="Simple AlexNet style CNN",
                                        metadata=dict(net_config))
        model.network.save("initialized_model.keras")
        model_artifact.add_file("initialized_model.keras")
        wandb.save("initialized_model.keras")
        run.log_artifact(model_artifact)

        # Add data artifact
        raw_data = wandb.Artifact(
            "mnist-raw",
            type="dataset",
            description="sklearn.datasets.load_boston",
            metadata={
                "source": "keras.datasets.mnist",
                #"size (rows)": [model.dataset.X.shape[0]]
            })
        with raw_data.new_file("raw" + ".npz", mode="wb") as file:
            np.savez(file, x=model.data.X, y=model.data.y)
        run.log_artifact(raw_data)

        preprocessed_data = wandb.Artifact(
            "mnist-processed",
            type="dataset",
            description="sklearn.datasets.load_boston",
            metadata={
                "source": "keras.datasets.mnist",
                #"size (rows)": [model.dataset.X.shape[0]]
            })
        with preprocessed_data.new_file("training" + ".npz",
                                        mode="wb") as file:
            np.savez(file, x=model.data.X_tr, y=model.data.y_tr)
        run.log_artifact(preprocessed_data)

        model.fit(dataset=config.dataset, callbacks=[WandbCallback()])
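
A minimal sketch of the artifact-logging pattern used above, reduced to its core calls. The project name, file name, and artifact name are made up; only the wandb.Artifact / add_file / log_artifact sequence is taken from the snippet.

import wandb

run = wandb.init(project="artifact-demo", mode="offline")

with open("weights.txt", "w") as f:
    f.write("dummy weights")           # stand-in for a saved model file

artifact = wandb.Artifact("demo-model", type="model",
                          description="illustrative model artifact")
artifact.add_file("weights.txt")       # attach the existing file on disk
run.log_artifact(artifact)             # version it with this run
run.finish()
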
Example #19
0
                               filter_text=config["filter_text"])

    config['train_loader'] = data.DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        collate_fn=train_dataset.get_collate_fn(),
        shuffle=True,
        pin_memory=True)
    config['val_loader'] = data.DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        collate_fn=val_dataset.get_collate_fn())
    config['test_loader'] = data.DataLoader(
        test_dataset,
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        collate_fn=test_dataset.get_collate_fn())

    try:
        trainer = TrainerUniter(config)
        trainer.train_main()
        wandb.save('vis_checkpoints/*', base_path="vis_checkpoints/")
        wandb.finish()
    except KeyboardInterrupt:
        LOGGER.warning(
            "Keyboard interrupt by user detected...\nClosing the tensorboard writer!"
        )
        config['writer'].close()
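
wandb.save also accepts a glob plus a base_path, as in the `vis_checkpoints/*` call above; base_path controls which directory prefix is stripped when the files are mirrored into the run. A small sketch of that call pattern with made-up paths:

import os
import wandb

run = wandb.init(project="save-demo", mode="offline")

os.makedirs("vis_checkpoints", exist_ok=True)
with open("vis_checkpoints/epoch_1.pt", "wb") as f:
    f.write(b"checkpoint placeholder")

# sync every file under vis_checkpoints/, relative to that directory
wandb.save("vis_checkpoints/*", base_path="vis_checkpoints/")
run.finish()
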
Example #20
0
File: vqa.py Project: j-min/VL-T5
    def train(self):
        if self.verbose:
            loss_meter = LossMeter()
            best_valid = 0.
            best_epoch = 0

            if 't5' in self.args.backbone:
                if self.args.use_vision:
                    project_name = "VLT5_VQA"
                else:
                    project_name = "T5_VQA"
            elif 'bart' in self.args.backbone:
                if self.args.use_vision:
                    project_name = "VLBart_VQA"
                else:
                    project_name = "Bart_VQA"

            wandb.init(project=project_name)
            wandb.run.name = self.args.run_name
            wandb.config.update(self.args)
            wandb.watch(self.model)

            src_dir = Path(__file__).resolve().parent
            base_path = str(src_dir.parent)
            src_dir = str(src_dir)
            wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)

        if self.args.distributed:
            dist.barrier()

        global_step = 0
        for epoch in range(self.args.epochs):
            if self.start_epoch is not None:
                epoch += self.start_epoch
            self.model.train()
            if self.args.distributed:
                self.train_loader.sampler.set_epoch(epoch)
            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=120)

            epoch_results = {
                'loss': 0.,

            }

            quesid2ans = {}

            for step_i, batch in enumerate(self.train_loader):

                if self.args.fp16 and _use_native_amp:
                    with autocast():
                        if self.args.distributed:
                            results = self.model.module.train_step(batch)
                        else:
                            results = self.model.train_step(batch)
                else:
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)

                loss = results['loss']

                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optim) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                loss = loss.detach()

                # Update Parameters
                if self.args.clip_grad_norm > 0:
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optim)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(
                            self.optim), self.args.clip_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)

                if self.args.fp16 and _use_native_amp:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()

                if self.lr_scheduler:
                    self.lr_scheduler.step()
                for param in self.model.parameters():
                    param.grad = None

                global_step += 1

                for k, v in results.items():
                    if k in epoch_results:
                        epoch_results[k] += v.item()

                if self.lr_scheduler:
                    if version.parse(torch.__version__) >= version.parse("1.4"):
                        lr = self.lr_scheduler.get_last_lr()[0]
                    else:
                        lr = self.lr_scheduler.get_lr()[0]
                else:
                    try:
                        lr = self.optim.get_lr()[0]
                    except AttributeError:
                        lr = self.args.lr

                if self.verbose:
                    loss_meter.update(loss.item())
                    desc_str = f'Epoch {epoch} | LR {lr:.6f}'
                    desc_str += f' | Loss {loss_meter.val:.4f}'

                    pbar.set_description(desc_str)
                    pbar.update(1)

                if self.args.distributed:
                    dist.barrier()

            if self.verbose:
                pbar.close()

            # Validation
            score_dict = self.evaluate(self.val_loader)

            if self.verbose:
                valid_score = score_dict['topk_score'] * 100.
                valid_score_raw = score_dict['overall']
                if valid_score_raw > best_valid or epoch == 0:
                    best_valid = valid_score_raw
                    best_epoch = epoch
                    self.save("BEST")

                log_str = ''
                log_str += "\nEpoch %d: Valid Raw %0.2f Topk %0.2f" % (epoch, valid_score_raw, valid_score)
                log_str += "\nEpoch %d: Best Raw %0.2f\n" % (best_epoch, best_valid)

                wandb_log_dict = {}
                wandb_log_dict['Train/Loss'] = epoch_results['loss'] / len(self.train_loader)

                wandb_log_dict['Valid/score'] = valid_score

                wandb_log_dict['Valid/raw_score'] = score_dict['overall']
                for qtype, score in score_dict['perQuestionType'].items():
                    wandb_log_dict[f'Valid_Qtypes/{qtype}'] = score
                for atype, score in score_dict['perAnswerType'].items():
                    if atype == 'yes/no':
                        atype = 'yes_no'
                    wandb_log_dict[f'Valid_Atypes/{atype}'] = score

                wandb.log(wandb_log_dict, step=epoch)
                print(log_str)

            if self.args.distributed:
                dist.barrier()

        if self.verbose:
            self.save("LAST")

        # Test Set
        best_path = os.path.join(self.args.output, 'BEST')
        self.load(best_path)

        quesid2ans = self.predict(self.test_loader)

        if self.verbose:
            evaluator = self.test_loader.evaluator
            score_dict = evaluator.evaluate(quesid2ans)

            evaluator.dump_result(quesid2ans)

            acc_dict_all = evaluator.evaluate_raw(quesid2ans)
            acc_dict_answerable = evaluator.evaluate_raw(quesid2ans, is_topk_optimal=True)
            acc_dict_unanswerable = evaluator.evaluate_raw(quesid2ans, is_topk_optimal=False)

            wandb_log_dict = {}
            wandb_log_dict['Test/overall'] = acc_dict_all['overall']
            wandb_log_dict['Test/topk_optimal'] = acc_dict_answerable['overall']
            wandb_log_dict['Test/topk_not_optimal'] = acc_dict_unanswerable['overall']

            for qtype, score in acc_dict_all['perQuestionType'].items():
                wandb_log_dict[f'Test_Qtypes/{qtype}'] = score
            for atype, score in acc_dict_all['perAnswerType'].items():
                if atype == 'yes/no':
                    atype = 'yes_no'
                wandb_log_dict[f'Test_Atypes/{atype}'] = score

            print(wandb_log_dict)
            wandb.log(wandb_log_dict)

        if self.args.submit:
            dump_path = os.path.join(self.args.output, 'submit.json')
            self.predict(self.submit_test_loader, dump_path)

            wandb.save(dump_path, base_path=self.args.output)
            wandb.log({'finished': True})

        if self.args.distributed:
            dist.barrier()
            exit()
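
The fp16 branch above follows the standard torch.cuda.amp recipe: run the forward pass under autocast, scale the loss, unscale before gradient clipping, then step through the scaler. A self-contained sketch of that loop on dummy data (it assumes a CUDA device; the model and tensors are placeholders, not the VL-T5 trainer):

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(16, 1).cuda()
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = GradScaler()

x = torch.randn(8, 16).cuda()
y = torch.randn(8, 1).cuda()

for _ in range(3):
    optim.zero_grad(set_to_none=True)
    with autocast():
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    scaler.unscale_(optim)                          # clip on unscaled grads
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optim)
    scaler.update()
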
Example #21
0
def train(config: BaseConfig, writer: SummaryWriter):
    memory = ReplayMemory(config.replay_memory_capacity)

    # create networks
    model = config.get_uniform_network().to(config.device)
    target_model = config.get_uniform_network().to(config.device)
    target_model.load_state_dict(model.state_dict())
    test_model = config.get_uniform_network().to(config.device)

    # create optimizers
    critic_optimizer = Adam([{
        'params': model.critic_1.parameters()
    }, {
        'params': model.critic_2.parameters()
    }],
                            lr=config.lr)
    policy_optimizer = Adam(model.actor.parameters(), lr=config.lr)

    # create envs
    env = config.new_game(seed=config.seed)
    test_env = config.new_game(seed=config.seed + 100)

    # training trackers
    total_env_steps = 0
    updates = 0
    best_test_score = float('-inf')

    # Fire!!
    for i_episode in itertools.count(1):
        done = False
        episode_steps, episode_reward, episode_repeats = 0, 0, []

        state = env.reset()
        while not done:
            epsilon = get_epsilon(config.max_epsilon, config.min_epsilon,
                                  total_env_steps, config.max_env_steps)

            if total_env_steps < config.start_step:
                action = env.action_space.sample()
                repeat_n = np.random.choice(model.action_repeats)
            else:
                with torch.no_grad():
                    # noisy action
                    state = torch.FloatTensor(state).unsqueeze(0).to(
                        config.device)
                    action = model.actor(state)
                    noise = Normal(torch.tensor([0.0]),
                                   torch.tensor([config.exploration_noise]))
                    action = action + noise.sample(
                        action.shape).squeeze(-1).to(config.device)
                    action = clip_action(action, config.action_space)

                    # epsilon-greedy repeat
                    if np.random.rand() <= epsilon:
                        repeat_idx = np.random.randint(
                            len(model.action_repeats))
                    else:
                        repeat_q = model.critic_1(state, action)
                        repeat_idx = repeat_q.argmax(1).item()

                state = state.data.cpu().numpy()[0]
                action = action.data.cpu().numpy()[0]
                repeat_n = model.action_repeats[repeat_idx]
            episode_repeats.append(repeat_n)

            # step
            step = 0
            discounted_reward_sum = 0
            next_states, rewards, terminals = [], [], []
            for repeat_i in range(1, repeat_n + 1):
                next_state, reward, done, info = env.step(action)
                discounted_reward_sum += (config.gamma**repeat_i) * (
                    reward / config.reward_scale_factor)
                episode_reward += reward

                # incr counters
                step += 1
                episode_steps += 1
                total_env_steps += 1

                # save data for each sub-repeat count
                if (repeat_i in model.action_repeats) or done:
                    next_states.append(next_state)
                    rewards.append(discounted_reward_sum)

                    # Ignore the "done" signal if it comes from hitting the time horizon.
                    terminal = 0 if (
                        ('TimeLimit.truncated' in info)
                        and info['TimeLimit.truncated']) else float(done)
                    terminals.append(terminal)

                # Test
                # Note: this is kept inside the env-step for-loop to keep test intervals synchronized across multiple seeds.
                if total_env_steps % config.test_interval_steps == 0:
                    test_model.load_state_dict(model.state_dict())
                    test_output = test(test_env, test_model,
                                       config.test_episodes)
                    if test_output.score > best_test_score:
                        best_test_score = test_output.score
                        torch.save(test_model.state_dict(),
                                   config.best_model_path)

                    # Test Log
                    writer.add_scalar('test/score', test_output.score,
                                      total_env_steps)
                    writer.add_scalar('test/avg_action_repeats',
                                      test_output.avg_repeat, total_env_steps)
                    test_logger.info(
                        '#{} test score: {} avg_action_repeats:{}'.format(
                            total_env_steps, test_output.score,
                            test_output.avg_repeat))

                if done:
                    break

            # add random data to be masked during batch processing.
            next_state_mask = [1 for _ in range(len(next_states))]
            if len(next_states) < len(model.action_repeats):
                next_state_mask += [
                    0 for _ in range(
                        len(model.action_repeats) - len(next_state_mask))
                ]

                # Note: these values will be ignored during update
                terminals += [
                    float('-inf')
                    for _ in range(len(model.action_repeats) - len(terminals))
                ]
                next_states += [
                    np.ones(next_states[-1].shape) for _ in range(
                        len(model.action_repeats) - len(next_states))
                ]
                rewards += [
                    float('-inf')
                    for _ in range(len(model.action_repeats) - len(rewards))
                ]

            # Add to memory
            memory.push(state, action, rewards, next_states, next_state_mask,
                        terminals)
            state = next_state

            # update network
            if len(memory) > config.batch_size:
                critic_1_loss, critic_2_loss, policy_loss = 0, 0, 0
                update_count = config.updates_per_step * step
                for i in range(update_count):
                    loss = update_params(model, target_model, critic_optimizer,
                                         policy_optimizer, memory, updates,
                                         config)
                    critic_1_loss += loss[0]
                    critic_2_loss += loss[1]
                    policy_loss += loss[2]

                    updates += 1

                # Log
                writer.add_scalar('train/critic_1_loss',
                                  critic_1_loss / update_count,
                                  total_env_steps)
                writer.add_scalar('train/critic_2_loss',
                                  critic_2_loss / update_count,
                                  total_env_steps)
                writer.add_scalar('train/policy_loss',
                                  policy_loss / update_count, total_env_steps)

        # log episode data
        writer.add_scalar('data/eps_reward', episode_reward, total_env_steps)
        writer.add_scalar('data/eps_steps', episode_steps, total_env_steps)
        writer.add_scalar('data/eps_repeats',
                          np.array(episode_repeats).mean(), total_env_steps)
        writer.add_scalar('data/episodes', i_episode, total_env_steps)
        writer.add_scalar('data/epsilon', epsilon, total_env_steps)
        writer.add_scalar('train/updates', updates, total_env_steps)

        _msg = '#{} train score:{} eps steps: {} total steps: {} updates : {}'
        _msg = _msg.format(i_episode, round(episode_reward, 2), episode_steps,
                           total_env_steps, updates)
        train_logger.info(_msg)

        # save model
        if i_episode % config.save_model_freq == 0:
            torch.save(model.state_dict(), config.model_path)
            if config.use_wandb:
                import wandb
                wandb.save(config.model_path, policy='now')

        # check if max. env steps reached.
        if total_env_steps > config.max_env_steps:
            train_logger.info('max env. steps reached!!')
            break

    # save the last updated model
    torch.save(model.state_dict(), config.model_path)
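
get_epsilon is not defined in this snippet. Given its (max_epsilon, min_epsilon, step, max_steps) arguments, a plausible reading is a linear anneal clipped at the minimum; the helper below is a guess along those lines, not the original implementation:

def get_epsilon(max_epsilon: float, min_epsilon: float,
                step: int, max_steps: int) -> float:
    """Linearly anneal epsilon from max_epsilon down to min_epsilon."""
    frac = min(step / max_steps, 1.0)
    return max_epsilon - frac * (max_epsilon - min_epsilon)


# e.g. get_epsilon(1.0, 0.05, 500_000, 1_000_000) -> 0.525
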
Example #22
0
    def __init__(self,
                 monitor='val_loss',
                 verbose=0,
                 mode='auto',
                 save_weights_only=False,
                 log_weights=False,
                 log_gradients=False,
                 save_model=True,
                 training_data=None,
                 validation_data=None,
                 labels=[],
                 data_type=None,
                 predictions=36,
                 generator=None,
                 input_type=None,
                 output_type=None,
                 log_evaluation=False,
                 validation_steps=None,
                 class_colors=None,
                 log_batch_frequency=None,
                 log_best_prefix="best_",
                 save_graph=True):
        if wandb.run is None:
            raise wandb.Error(
                'You must call wandb.init() before WandbCallback()')

        self.validation_data = None
        # This is kept around for legacy reasons
        if validation_data is not None:
            if is_generator_like(validation_data):
                generator = validation_data
            else:
                self.validation_data = validation_data

        self.labels = labels
        self.predictions = min(predictions, 100)

        self.monitor = monitor
        self.verbose = verbose
        self.save_weights_only = save_weights_only
        self.save_graph = save_graph

        wandb.save('model-best.h5')
        self.filepath = os.path.join(wandb.run.dir, 'model-best.h5')
        self.save_model = save_model
        self.log_weights = log_weights
        self.log_gradients = log_gradients
        self.training_data = training_data
        self.generator = generator
        self._graph_rendered = False

        self.input_type = input_type or data_type
        self.output_type = output_type
        self.log_evaluation = log_evaluation
        self.validation_steps = validation_steps
        self.class_colors = np.array(
            class_colors) if class_colors is not None else None
        self.log_batch_frequency = log_batch_frequency
        self.log_best_prefix = log_best_prefix

        if self.training_data:
            if len(self.training_data) != 2:
                raise ValueError("training data must be a tuple of length two")

        # From Keras
        if mode not in ['auto', 'min', 'max']:
            print('WandbCallback mode %s is unknown, '
                  'fallback to auto mode.' % (mode))
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = operator.lt
            self.best = float('inf')
        elif mode == 'max':
            self.monitor_op = operator.gt
            self.best = float('-inf')
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = operator.gt
                self.best = float('-inf')
            else:
                self.monitor_op = operator.lt
                self.best = float('inf')
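
The callback registers 'model-best.h5' with wandb.save before the file exists and then writes checkpoints into wandb.run.dir, relying on the run directory being synced. A stripped-down sketch of the same pattern (the file contents are placeholders):

import os
import wandb

run = wandb.init(project="callback-demo", mode="offline")

wandb.save("model-best.h5")                   # register the file for upload
ckpt_path = os.path.join(wandb.run.dir, "model-best.h5")
with open(ckpt_path, "wb") as f:
    f.write(b"placeholder weights")           # files in run.dir get synced
run.finish()
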
Example #23
0
            sample_per_sec = BATCH_SIZE * 10 / (time.time() - t)
            log["sample_per_sec"] = sample_per_sec
            print(epoch, i, f'sample_per_sec - {sample_per_sec}')

        if i == 201 and args.flops_profiler:
            raise StopIteration("Profiler has finished running. Stopping training early.")

        if distr_backend.is_root_worker():
            wandb.log(log)

    if LR_DECAY:
        distr_scheduler.step(avg_loss)

    save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
    
    if distr_backend.is_root_worker():
        # save the trained model to wandb as an artifact at the end of every epoch

        model_artifact = wandb.Artifact('trained-dalle', type='model', metadata=dict(model_config))
        model_artifact.add_file(DALLE_OUTPUT_FILE_NAME)
        run.log_artifact(model_artifact)

save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
if distr_backend.is_root_worker():
    wandb.save(DALLE_OUTPUT_FILE_NAME)
    model_artifact = wandb.Artifact('trained-dalle', type='model', metadata=dict(model_config))
    model_artifact.add_file(DALLE_OUTPUT_FILE_NAME)
    run.log_artifact(model_artifact)

    wandb.finish()
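
Logging is guarded so that only the root worker talks to wandb. The snippet's distr_backend helper is not shown; a common equivalent guard with plain torch.distributed would look like the sketch below (an assumption, not the snippet's actual backend):

import torch.distributed as dist


def is_root_worker() -> bool:
    # rank 0 when a process group is initialized, otherwise single process
    if not dist.is_available() or not dist.is_initialized():
        return True
    return dist.get_rank() == 0


if is_root_worker():
    print("would call wandb.log(...) / wandb.save(...) here")
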
Example #24
0
            # accuracy sketches
            _, max_idx = torch.max(pred_logits_second, dim=1)
            running_acc_sketches += torch.sum(max_idx == second_label).item()
            avg_sketches_val_acc = running_acc_sketches / items * 100

        avg_val_loss = val_total_loss / len(val_loader)

        if not args.debug:
            wandb.log({
                'val/loss': avg_val_loss,
                'val/acc flickr': avg_val_flicker_acc,
                'val/acc sketches': avg_sketches_val_acc,
                'epoch': epoch + 1
            })

            # checkpointing: only persist the model when validation loss improves
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                model_name = f"best_{contrastive_net.__class__.__name__}_contrastive.pth"
                torch.save(contrastive_net.state_dict(), model_name)
                if not args.debug:
                    wandb.save(model_name)
                tag = '*'

            sys.stdout.write(
                f", Val[Loss: {avg_val_loss}, flickr acc: {avg_val_flicker_acc}, sketches acc: {avg_sketches_val_acc} {tag}"
            )
            sys.stdout.flush()
            sys.stdout.write('\n')
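
The checkpointing guard above is meant to persist weights only when the validation loss improves. The same guard in isolation, with a toy module standing in for contrastive_net:

import torch
import torch.nn as nn

net = nn.Linear(4, 2)
best_val_loss = float("inf")

for epoch, avg_val_loss in enumerate([0.9, 0.7, 0.8]):
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        path = f"best_{net.__class__.__name__}.pth"
        torch.save(net.state_dict(), path)   # keep only improving checkpoints
        print(f"epoch {epoch}: new best {best_val_loss:.2f} -> {path}")
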
Example #25
0
def main(cfg: DictConfig):
    # instantiate Wandb Logger
    wandblogger = WandbLogger(project=cfg.general.project_name,
                              log_model=True,
                              name=cfg.training.job_name)
    # Log Hyper-parameters to Wandb
    wandblogger.log_hyperparams(cfg)

    # set random seeds so that results are reproducible
    seed_everything(cfg.training.random_seed)

    # generate a random idx for the job
    if cfg.training.unique_idx is None:
        cfg.training.unique_idx = generate_random_id()

    uq_id = cfg.training.unique_idx
    model_name = f"{cfg.training.encoder}-fold={cfg.training.fold}-{uq_id}"

    # Set up Callbacks to assist in Training
    cbs = [
        WandbTask(),
        DisableValidationBar(),
        LogInformationCallback(),
        LearningRateMonitor(logging_interval="step"),
    ]

    if cfg.training.patience is not None:
        cbs.append(
            EarlyStopping(monitor="valid/acc",
                          patience=cfg.training.patience,
                          mode="max"))

    checkpointCallback = ModelCheckpoint(
        monitor="valid/acc",
        save_top_k=1,
        mode="max",
    )
    # set up trainer kwargs
    kwds = dict(checkpoint_callback=checkpointCallback,
                callbacks=cbs,
                logger=wandblogger)

    trainer = instantiate(cfg.trainer, **kwds)

    # set up cassava image classification Task
    model = Task(cfg)

    trainer.fit(model)

    # Load in the best checkpoint and save the model weights
    checkpointPath = checkpointCallback.best_model_path
    # Testing Stage
    _ = trainer.test(verbose=True, ckpt_path=checkpointPath)

    # load in the best model weights
    model = Task.load_from_checkpoint(checkpointPath)

    # create model save dir to save the weights of the
    # vanilla torch-model
    os.makedirs(cfg.general.save_dir, exist_ok=True)
    path = os.path.join(cfg.general.save_dir, f"{model_name}.pt")
    # save the weights of the model
    torch.save(model.model.state_dict(), f=path)
    # upload trained weights to wandb
    wandb.save(path)

    # save the original composed config file to wandb
    conf_path = os.path.join(cfg.general.save_dir, "cfg.yml")
    OmegaConf.save(cfg, f=conf_path)
    wandb.save(conf_path)
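
Alongside the weights, the run uploads the composed Hydra config. A tiny sketch of that last step, serializing an OmegaConf config to YAML and pushing it to wandb (the config keys are illustrative):

import wandb
from omegaconf import OmegaConf

cfg = OmegaConf.create({"training": {"encoder": "resnet34", "fold": 0}})

run = wandb.init(project="cfg-demo", mode="offline")
OmegaConf.save(cfg, f="cfg.yml")   # write the config as YAML
wandb.save("cfg.yml")              # upload it next to the run
run.finish()
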
Example #26
0
def main(cfg: DictConfig):
    cur_dir = hydra.utils.get_original_cwd()
    os.chdir(cur_dir)
    seed_everything(cfg.data.seed)

    # wandb
    wandb.init(project='VinBigData-Detection')
    wandb.config.update(dict(cfg.data))
    wandb.config.update(dict(cfg.train))
    wandb.config.update(dict(cfg.aug_kwargs_detection))
    wandb.config.update(dict(cfg.classification_kwargs))

    # omegaconf -> dict
    rep_aug_kwargs = OmegaConf.to_container(cfg.aug_kwargs_detection)

    class_name_dict = {
        0: 'Aortic enlargement',
        1: 'Atelectasis',
        2: 'Calcification',
        3: 'Cardiomegaly',
        4: 'Consolidation',
        5: 'ILD',
        6: 'Infiltration',
        7: 'Lung Opacity',
        8: 'Nodule/Mass',
        9: 'Other lesion',
        10: 'Pleural effusion',
        11: 'Pleural thickening',
        12: 'Pneumothorax',
        13: 'Pulmonary fibrosis',
    }

    # Setting  --------------------------------------------------
    data_dir = cfg.data.data_dir
    output_dir = cfg.data.output_dir
    img_size = cfg.data.img_size
    backbone = cfg.data.backbone
    use_class14 = cfg.data.use_class14

    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)

    if use_class14:
        class_name_dict.update({14: 'No finding'})

    # Register Dataset  --------------------------------------------------
    anno_df = pd.read_csv(os.path.join(data_dir, 'train_wbf_th0.7.csv'))

    if cfg.data.use_class14:
        pass
    else:
        anno_df = anno_df[anno_df['class_id'] != 14].reset_index(drop=True)

    # Extract rad id
    if cfg.data.rad_id != 'all':
        anno_df = anno_df[anno_df['rad_id'].isin(cfg.data.rad_id)].reset_index()

    if debug:
        anno_df = anno_df.head(100)

    # Split train, valid data - random
    if 'valid' in cfg.data.split_method:
        split_rate = float(cfg.data.split_method.split('_')[1]) / 100
        unique_image_ids = anno_df['image_id'].values
        unique_image_ids = np.random.RandomState(cfg.data.seed).permutation(unique_image_ids)
        train_image_ids = unique_image_ids[:int(len(unique_image_ids) * (1 - split_rate))]
        valid_image_ids = unique_image_ids[int(len(unique_image_ids) * (1 - split_rate)):]
        DatasetCatalog.register("xray_valid", lambda d='valid': get_xray_dict(anno_df, data_dir, cfg, valid_image_ids))
        MetadataCatalog.get("xray_valid").set(thing_classes=list(class_name_dict.values()))

    else:
        train_image_ids = anno_df['image_id'].values
    DatasetCatalog.register("xray_train", lambda d='train': get_xray_dict(anno_df, data_dir, cfg, train_image_ids))
    MetadataCatalog.get("xray_train").set(thing_classes=list(class_name_dict.values()))

    DatasetCatalog.register("xray_test", lambda d='test': get_test_xray_dict(data_dir))
    MetadataCatalog.get("xray_test").set(thing_classes=list(class_name_dict.values()))

    # Config  --------------------------------------------------
    detectron2_cfg = get_cfg()
    detectron2_cfg.aug_kwargs = CN(rep_aug_kwargs)
    detectron2_cfg.merge_from_file(model_zoo.get_config_file(backbone))
    detectron2_cfg.DATASETS.TRAIN = ("xray_train",)
    if 'valid' in cfg.data.split_method:
        detectron2_cfg.DATASETS.TEST = ("xray_valid",)
        detectron2_cfg.TEST.EVAL_PERIOD = cfg.train.max_iter // 10
    else:
        detectron2_cfg.DATASETS.TEST = ()
    detectron2_cfg.INPUT.MIN_SIZE_TRAIN = (img_size,)
    detectron2_cfg.INPUT.MAX_SIZE_TRAIN = img_size
    detectron2_cfg.DATALOADER.NUM_WORKERS = cfg.train.num_workers
    detectron2_cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(backbone)
    detectron2_cfg.SOLVER.IMS_PER_BATCH = cfg.train.ims_per_batch
    detectron2_cfg.SOLVER.BASE_LR = cfg.train.lr
    detectron2_cfg.SOLVER.MAX_ITER = cfg.train.max_iter
    detectron2_cfg.SOLVER.LR_SCHEDULER_NAME = "WarmupCosineLR"
    detectron2_cfg.SOLVER.WARMUP_ITERS = 2000
    detectron2_cfg.SOLVER.CHECKPOINT_PERIOD = 200000
    detectron2_cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = cfg.train.batch_size_per_image
    detectron2_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 15 if use_class14 else 14
    detectron2_cfg.OUTPUT_DIR = output_dir
    detectron2_cfg.SEED = cfg.data.seed
    detectron2_cfg.PIXEL_MEAN = [103.530, 116.280, 123.675]
    detectron2_cfg.PIXEL_STD = [1.0, 1.0, 1.0]

    # Train  --------------------------------------------------
    os.makedirs(detectron2_cfg.OUTPUT_DIR, exist_ok=True)
    # trainer = DefaultTrainer(detectron2_cfg)
    trainer = MyTrainer(detectron2_cfg)
    trainer.resume_or_load(resume=True)
    trainer.train()

    # Rename Last Weight
    renamed_model = f"{backbone.split('.')[0].replace('/', '-')}.pth"
    os.rename(os.path.join(cfg.data.output_dir, 'model_final.pth'),
              os.path.join(cfg.data.output_dir, renamed_model))

    # Logging
    for model_path in glob.glob(os.path.join(cfg.data.output_dir, '*.pth')):
        wandb.save(model_path)

    # Inference Setting  ------------------------------------------------------
    detectron2_cfg = get_cfg()
    detectron2_cfg.merge_from_file(model_zoo.get_config_file(backbone))
    detectron2_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 15 if use_class14 else 14
    detectron2_cfg.MODEL.WEIGHTS = os.path.join(output_dir, renamed_model)  # path to the model we just trained
    detectron2_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = cfg.data.score_th   # set a custom testing threshold

    predictor = DefaultPredictor(detectron2_cfg)
    dataset_dicts = get_test_xray_dict(data_dir)

    # Visualize  ------------------------------------------------------
    target_image_ids = ['9a5094b2563a1ef3ff50dc5c7ff71345',
                        '22b8e616a61bbc4caaed0cf23b7159df',
                        '001d127bad87592efe45a5c7678f8b8d',
                        '008b3176a7248a0a189b5731ac8d2e95']

    for th in [0, 0.2, 0.5, 0.7]:
        visualize(target_image_ids, data_dir, output_dir, predictor, score_th=th)

    # Metrics
    if os.path.exists(os.path.join(output_dir, 'metrics.json')):
        metrics_df = pd.read_json(os.path.join(output_dir, 'metrics.json'), orient="records", lines=True)
        mdf = metrics_df.sort_values("iteration")

        mdf3 = mdf[~mdf["bbox/AP75"].isna()].reset_index(drop=True)
        for i in range(len(mdf3)):
            row = mdf3.iloc[i]
            wandb.log({'AP40': row["bbox/AP75"] / 100.})

        best_score = mdf3["bbox/AP75"].max() / 100.
        wandb.log({'Best-AP40-Score': best_score})

    # Inference  ------------------------------------------------------
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    sub = get_submission(dataset_dicts, cfg, predictor, device)

    now = datetime.datetime.now() + datetime.timedelta(hours=9)
    now = now.strftime("%Y%m%d-%H%M%S")

    filename = f'submission_{now}.csv'
    sub.to_csv(os.path.join('./submission', filename), index=False)
    wandb.save(os.path.join('./submission', filename))
    time.sleep(30)

    wandb.finish()
    DatasetCatalog.clear()
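
The train/valid split above permutes the image ids with a fixed seed and slices off the last split_rate fraction for validation. The same arithmetic on a toy id array, so the slicing is easy to verify (the ids and split rate are made up):

import numpy as np

image_ids = np.array([f"img_{i:02d}" for i in range(10)])
split_rate = 0.2                    # e.g. split_method 'valid_20' -> 20 / 100

shuffled = np.random.RandomState(42).permutation(image_ids)
cut = int(len(shuffled) * (1 - split_rate))
train_ids, valid_ids = shuffled[:cut], shuffled[cut:]

print(len(train_ids), len(valid_ids))  # 8 2
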
Example #27
0
File: trainer.py Project: realRaBot/mgn
    def train(self):
        training_mode = 'reinforce' if self.reinforce else 'seq2seq'
        print('| start training %s, running in directory %s' %
              (training_mode, self.run_dir))
        num_workers = self.opt.num_workers
        t = 0
        epoch = 0
        baseline = 0
        start = time.time()
        while t < self.num_iters:
            epoch += 1
            for x, y, ans, idx, g_data in self.train_loader:
                t += 1
                loss, reward = None, None
                self.model.set_input(x, y, g_data)
                self.optimizer.zero_grad()
                if self.reinforce:
                    pred = self.model.reinforce_forward()
                    reward = self.get_batch_reward(pred, ans, idx, 'train')
                    baseline = reward * (
                        1 - self.reward_decay) + baseline * self.reward_decay
                    advantage = reward - baseline
                    self.model.set_reward(advantage)
                    self.model.reinforce_backward(self.entropy_factor)
                else:
                    loss = self.model.supervised_forward()
                    self.model.supervised_backward()
                self.optimizer.step()

                if t % self.display_every == 0:
                    if self.reinforce:
                        self.stats['train_batch_accs'].append(reward)
                        self.log_stats('training batch reward', reward, t)
                        print('| iteration %d / %d, epoch %d, reward %f' %
                              (t, self.num_iters, epoch, reward))
                    else:
                        self.stats['train_losses'].append(loss)
                        self.log_stats('training batch loss', loss, t)
                        print('| iteration %d / %d, epoch %d, loss %f' %
                              (t, self.num_iters, epoch, loss))
                    self.stats['train_accs_ts'].append(t)

                if t % self.checkpoint_every == 0 or t >= self.num_iters:
                    print('| checking validation accuracy')
                    val_acc = self.check_val_accuracy()
                    print('| validation accuracy %f' % val_acc)
                    if val_acc >= self.stats['best_val_acc']:
                        print('| best model')
                        self.stats['best_val_acc'] = val_acc
                        self.stats['model_t'] = t
                        checkpoint_fp = f"{self.run_dir}/checkpoint_best.pt"
                        self.model.save_checkpoint(checkpoint_fp)
                        self.model.save_checkpoint(
                            '%s/checkpoint_iter%08d.pt' % (self.run_dir, t))
                        if self.visualize_training_wandb:
                            wandb.save(checkpoint_fp)
                    if not self.reinforce:
                        val_loss = self.check_val_loss()
                        print('| validation loss %f' % val_loss)
                        self.stats['val_losses'].append(val_loss)
                        self.log_stats('val loss', val_loss, t)
                    self.stats['val_accs'].append(val_acc)
                    self.log_stats('val accuracy', val_acc, t)
                    self.stats['val_accs_ts'].append(t)
                    # Save Checkpoint #
                    self.model.save_checkpoint('%s/checkpoint.pt' %
                                               self.run_dir)

                    checkpoint_dict = {
                        'args': self.opt.__dict__,
                        'stats': self.stats
                    }
                    json_fp = '%s/stats.json' % self.run_dir
                    #json_fp = os.path.join(self.run_dir, f'stats_{self.opt.run_timestamp}.json')
                    logging.info(f'Saving train_opt_[ts].json at: {json_fp}')
                    with open(json_fp, 'w') as f:
                        json.dump(self.stats, f, indent=2)
                        #json.dump(checkpoint_dict, f, indent=1)
                    if self.visualize_training_wandb:
                        wandb.save(json_fp)
                    self.log_params(t)

                if t >= self.num_iters:
                    break
        end = time.time()
        logging.info(
            f"Finished {self.num_iters} iterations or {epoch} epochs in {end - start}s, with {num_workers} workers"
        )
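
# A small numeric sketch, not from the original trainer, of the REINFORCE
# baseline update used above: an exponential moving average of the reward,
# with the advantage being reward minus baseline. Reward values are made up.
def _sketch_reinforce_baseline():
    reward_decay = 0.99
    baseline = 0.0
    for reward in [0.2, 0.5, 0.1]:
        baseline = reward * (1 - reward_decay) + baseline * reward_decay
        advantage = reward - baseline
        print(f"reward={reward:.2f} baseline={baseline:.4f} "
              f"advantage={advantage:+.4f}")
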
def main(args):
    args.embedding = ''
    if args.method != 'distance':
        args.embedding = 'es{}'.format(args.emb_size)

    if args.method == 'node2vec':
        args.embedding += '_nw{}_wl{}_p{}_q{}'.format(args.num_walks,
                                                      args.walk_len, args.p,
                                                      args.q)

    if args.chr_tgt is None:
        args.chr_tgt = args.chr_src

    if args.full_coexpression or args.full_interactions:
        chrs_coexp = 'all'
    else:
        chrs_coexp = '{}_{}'.format(args.chr_src, args.chr_tgt)

    if args.full_interactions:
        chrs_interactions = 'all'
    else:
        chrs_interactions = '{}_{}'.format(args.chr_src, args.chr_tgt)

    # ToDo: add constraint that windows and thresholds have to have the same length
    hic_files = [
        'primary_{}_{}'.format(type, norm)
        for type, norm in zip(args.types, args.norms)
    ]

    hic_preprocessings = [
        '{}_{}{}'.format(window, threshold,
                         ('_' + str(weight)) if weight else '') for window,
        threshold, weight in zip(args.windows, args.hic_thrs, args.weights)
    ]

    args.interactions = [
        '{}_{}_{}'.format(file, chrs_interactions, preprocessing)
        for file, preprocessing in zip(hic_files, hic_preprocessings)
    ]

    interactions_no_chr = [
        '{}_{}'.format(file, preprocessing)
        for file, preprocessing in zip(hic_files, hic_preprocessings)
    ]
    args.interactions = '_'.join(args.interactions)

    args.aggregators = '_'.join(args.aggregators)
    coexp_thrs = '_'.join(args.coexp_thrs)

    if args.coexp_features:
        args.folder = 'coexpression_networks'
        args.name = 'chr_{}_{}'.format(chrs_coexp, coexp_thrs)
        experiment_id = '{}_{}_{}_{}_{}_{}_{}'.format(
            args.dataset, args.classifier, args.cv_splits, coexp_thrs,
            args.method, args.embedding, args.aggregators)
    else:
        args.folder = 'chromatin_networks'
        args.name = args.interactions
        experiment_id = '{}_{}_{}_{}_{}_{}_{}_'.format(
            args.dataset, args.classifier, args.cv_splits, coexp_thrs,
            args.method, '_'.join(interactions_no_chr), args.embedding)
        experiment_id += '_' + str(args.aggregators)

    args.embedding = args.name + '_' + args.embedding

    id_hash = str(
        int(hashlib.sha1(experiment_id.encode()).hexdigest(), 16) % (10**8))

    os.makedirs('../../results/{}/chr_{}'.format(args.dataset, args.chr_src),
                exist_ok=True)
    os.makedirs('../../results/{}/predictions/chr_{}'.format(
        args.dataset, args.chr_src),
                exist_ok=True)

    if args.method == 'topological':
        if args.full_coexpression:
            filename = '{}chr_all/{}_{}_{}_{}.pkl'.format(
                'test/' if args.test else '', args.classifier, args.method,
                args.name, args.aggregators)
        else:
            filename = '{}chr_{}/{}_{}_{}_{}.pkl'.format(
                'test/' if args.test else '', args.chr_src, args.classifier,
                args.method, args.name, args.aggregators)
    else:
        if args.full_coexpression:
            filename = '{}chr_all/{}_{}_{}_{}_{}.pkl'.format(
                'test/' if args.test else '', args.classifier, args.method,
                args.embedding, args.aggregators, coexp_thrs)
        else:
            filename = '{}chr_{}/{}_{}_{}_{}_{}.pkl'.format(
                'test/' if args.test else '', args.chr_src, args.classifier,
                args.method, args.embedding, args.aggregators, coexp_thrs)

    if not os.path.exists('../../results/{}/{}'.format(
            args.dataset, filename)) or args.force:
        coexpression = np.load(
            '../../data/{}/coexpression_networks/coexpression_chr_{}_{}.npy'.
            format(args.dataset, chrs_coexp, coexp_thrs))
        chr_sizes = np.load('../../data/{}/chr_sizes.npy'.format(args.dataset))

        disconnected_nodes = np.load(
            '../../data/{}/disconnected_nodes/{}.npy'.format(
                args.dataset, args.name))

        start_src = None
        end_src = None

        if args.full_interactions and not args.full_coexpression:
            start_src = int(np.sum(chr_sizes[:args.chr_src]))
            end_src = int(start_src + chr_sizes[args.chr_src])

            start_tgt = int(np.sum(chr_sizes[:args.chr_tgt]))
            end_tgt = int(start_tgt + chr_sizes[args.chr_tgt])

            coexpression = coexpression[start_src:end_src, start_tgt:end_tgt]

            disconnected_nodes_src = disconnected_nodes[
                (disconnected_nodes >= start_src)
                & (disconnected_nodes < end_src)] - start_src
            disconnected_nodes_tgt = disconnected_nodes[
                (disconnected_nodes >= start_tgt)
                & (disconnected_nodes < end_tgt)] - start_tgt
        else:
            disconnected_nodes_src = disconnected_nodes
            disconnected_nodes_tgt = disconnected_nodes

        print("N. disconnected nodes:", len(disconnected_nodes_src))
        if len(disconnected_nodes) > 0:
            coexpression[disconnected_nodes_src] = 0
            coexpression[:, disconnected_nodes_tgt] = 0

        n_nodes = coexpression.shape[0]

        if args.full_coexpression:
            shapes = [
                np.load(
                    '../../data/{}/coexpression/coexpression_chr_{}_{}.npy'.
                    format(args.dataset, i, i)).shape for i in range(1, 23)
            ]

            mask = intra_mask(shapes)

            coexpression_intra = coexpression * mask
        else:
            coexpression_intra = coexpression
            mask = None

        edges_intra = np.array(np.argwhere(coexpression_intra == 1))
        edges_intra_nodes = np.unique(edges_intra)

        non_nodes_intra = np.setdiff1d(np.arange(n_nodes), edges_intra_nodes)

        coexpression_intra_neg = coexpression_intra.copy()
        coexpression_intra_neg[non_nodes_intra, :] = np.nan
        coexpression_intra_neg[:, non_nodes_intra] = np.nan

        non_edges_intra = np.array(np.argwhere(coexpression_intra_neg == 0))
        non_edges_intra = non_edges_intra[np.random.choice(
            non_edges_intra.shape[0], edges_intra.shape[0], replace=False)]

        edges = edges_intra
        non_edges = non_edges_intra

        if not args.coexp_intra:
            coexpression_inter = coexpression * np.logical_not(mask)

            edges_inter = np.array(np.argwhere(coexpression_inter == 1))

            if edges_intra.shape[0] > edges_inter.shape[0]:
                n_edges_inter = edges_inter.shape[0]
            else:
                n_edges_inter = int(edges_intra.shape[0] * args.inter_ratio)
            print('N. intra edges', edges_intra.shape[0], '- N. inter edges ',
                  edges_inter.shape[0], '->', n_edges_inter)
            edges_inter = edges_inter[np.random.choice(edges_inter.shape[0],
                                                       n_edges_inter,
                                                       replace=False)]
            edges_inter_nodes = np.unique(edges_inter)

            non_nodes_inter = np.setdiff1d(np.arange(n_nodes),
                                           edges_inter_nodes)

            coexpression_inter_neg = coexpression_inter.copy()
            coexpression_inter_neg[non_nodes_inter, :] = np.nan
            coexpression_inter_neg[:, non_nodes_inter] = np.nan

            non_edges_inter = np.array(
                np.argwhere(coexpression_inter_neg == 0))
            non_edges_inter = non_edges_inter[np.random.choice(
                non_edges_inter.shape[0], edges_inter.shape[0], replace=False)]

            edges = np.vstack((edges, edges_inter))
            non_edges = np.vstack((non_edges, non_edges_inter))

        n_edges = edges.shape[0]
        n_non_edges = non_edges.shape[0]

        if args.wandb:
            wandb.init(project="coexp-inference-models")
            wandb.config.update({
                'id':
                id_hash,
                'dataset':
                args.dataset,
                'fold':
                args.cv_splits,
                'windows':
                '_'.join(map(str, args.windows)),
                'chr src':
                args.chr_src,
                'chr tgt':
                args.chr_tgt,
                'hic thresholds':
                '_'.join(map(str, args.hic_thrs)),
                'coexp threshold':
                coexp_thrs,
                'full interactions':
                args.full_interactions,
                'full coexpression':
                args.full_coexpression,
                'embedding method':
                args.method,
                'aggregators':
                args.aggregators,
                'classifier':
                args.classifier,
                'interactions':
                args.interactions,
                'embeddings size':
                args.emb_size,
                'test':
                args.test
            })

        if args.method == 'topological':
            X = topological_features(args, edges, non_edges)
        elif args.method == 'ids':
            X = np.vstack((edges, non_edges))
        elif args.method == 'distance':
            X = distance_embedding(args.full_interactions, args.dataset,
                                   args.chr_src, edges, non_edges)
        else:
            X = method_embedding(args, n_nodes, edges, non_edges, start_src,
                                 end_src)
        y = np.hstack((np.ones(n_edges), np.zeros(n_non_edges)))
        X_train, X_test, y_train, y_test = train_test_split(X,
                                                            y,
                                                            test_size=0.2,
                                                            shuffle=True)

        results = defaultdict(list)
        if args.test:
            results = evaluate_embedding(X_train,
                                         y_train,
                                         args.classifier,
                                         verbose=1,
                                         clf_params={'n_estimators': 100},
                                         mask=mask,
                                         X_test=X_test,
                                         y_test=y_test)
            if args.wandb:
                wandb.run.summary.update({
                    'accuracy': results['acc'],
                    'accuracy std': results['acc'],
                    'precision': results['precision'],
                    'precision std': results['precision'],
                    'recall': results['recall'],
                    'recall std': results['recall']
                })
        else:
            for i in range(args.n_iter):
                results_iter = evaluate_embedding(
                    X_train,
                    y_train,
                    args.classifier,
                    verbose=1,
                    clf_params={'n_estimators': 100},
                    cv_splits=args.cv_splits,
                    mask=mask)
                for key in results_iter.keys():
                    results[key].extend(results_iter[key])

            if args.wandb:
                wandb.run.summary.update({
                    'accuracy': np.mean(results['acc']),
                    'accuracy std': np.std(results['acc']),
                    'precision': np.mean(results['precision']),
                    'precision std': np.std(results['precision']),
                    'recall': np.mean(results['recall']),
                    'recall std': np.std(results['recall'])
                })

        with open('../../results/{}/{}'.format(args.dataset, filename),
                  'wb') as file_save:
            pickle.dump(results, file_save)

        if args.wandb:
            wandb.save('../../results/{}/{}'.format(args.dataset, filename))

        print("Mean Accuracy:", np.mean(results['acc']), "- Mean ROC:",
              np.mean(results['roc']), "- Mean F1:",
              np.mean(results['f1']), "- Mean Precision:",
              np.mean(results['precision']), "- Mean Recall",
              np.mean(results['recall']))
    else:
        print('Result already computed for {}. Skipped.'.format(filename))
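
id_hash above turns the experiment_id string into a short, stable numeric identifier by hashing it with SHA-1 and reducing the digest modulo 10**8. A worked toy version of the same construction (the experiment_id string is made up):

import hashlib

experiment_id = "dataset_rf_5_0.9_node2vec_primary_obs_es16_mean"  # illustrative
digest = hashlib.sha1(experiment_id.encode()).hexdigest()
id_hash = str(int(digest, 16) % (10 ** 8))
print(id_hash)   # at most 8 digits, identical every time for the same string
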
Example #29
0
def train(generator,
          discriminator,
          autoencoder,
          g_optim,
          d_optim,
          step,
          iteration=0,
          startpoint=0,
          used_sample=0,
          d_losses=[],
          g_losses=[],
          alpha=0):

    std = 0.2

    resolution = (25 * 2**step, 8 * 2**step)

    origin_loader = gain_sample(batch_size[step], resolution)
    data_loader = iter(origin_loader)

    # ae_step = 4
    # ae_resolution = (25 * 2 ** ae_step, 8 * 2 ** ae_step)
    # fd_calculator = DomainFD(autoencoder, ae_resolution, device=device)

    # ae_data_loader = iter(gain_sample(359, ae_resolution))
    # fd_calculator.fit_real_data(ae_data_loader)

    reset_LR(g_optim, args.lr)
    reset_LR(d_optim, args.lr)

    progress_bar = tqdm(total=n_sample_total, initial=used_sample)
    # Train
    while used_sample < n_sample_total:

        # done 800 x 256 step
        if used_sample > 1_000_000:
            std = args.instance_noise - (
                used_sample - 1_000_000) * args.instance_noise / 600_000
            std = max(std, 0)

        iteration += 1
        # alpha = min(1, alpha + batch_size[step] / (args.n_sample))
        alpha = 1

        #
        # if (used_sample - startpoint) > args.n_sample and step < max_step:
        #     step += 1
        #     print("Now on step", step)
        #     alpha = 0
        #     startpoint = used_sample
        #
        #     resolution = (25 * 2 ** step, 8 * 2 ** step)
        #
        #     # Avoid possible memory leak
        #     del origin_loader
        #
        #     # Change batch size
        #     # apply resizing
        #     origin_loader = gain_sample(batch_size[step], resolution)
        #
        #     data_loader = iter(origin_loader)
        #
        #     reset_LR(g_optim, args.lr)
        #     reset_LR(d_optim, args.lr)

        # D Update
        real_image = sample_data(data_loader, origin_loader)
        real_image.requires_grad = True

        # Count used sample
        used_sample += real_image.shape[0]
        progress_bar.update(real_image.shape[0])

        # Send image to GPU
        real_image = real_image.to(device)

        # Spectral Regularization after gradient step
        if args.spectral_reg:
            discriminator.update_sr(step, alpha)

        real_predict = discriminate(discriminator, real_image, step, alpha,
                                    std, args.n_gpu)
        real_loss = nn.functional.softplus(-real_predict).mean()

        fake_image = generate(generator, step, alpha, random_mix_steps(),
                              resolution, args.n_gpu)
        fake_predict = discriminate(discriminator, fake_image, step, alpha,
                                    std, args.n_gpu)

        fake_loss = nn.functional.softplus(fake_predict).mean()

        d_loss = (real_loss + fake_loss).item()

        grad_real = torch.autograd.grad(outputs=real_loss.sum(),
                                        inputs=real_image,
                                        create_graph=True)[0]
        grad_penalty_real = (grad_real.view(grad_real.size(0),
                                            -1).norm(2, dim=1)**2).mean()
        grad_penalty_real = 10 / 2 * grad_penalty_real

        d_optim.zero_grad()
        real_loss.backward(retain_graph=True)
        fake_loss.backward()
        grad_penalty_real.backward()
        d_optim.step()

        del real_image, real_predict, real_loss, fake_image, fake_predict, fake_loss, grad_penalty_real

        # G Update

        for i in range(args.g_steps):
            fake_image = generate(generator, step, alpha, random_mix_steps(),
                                  resolution, args.n_gpu)
            fake_predict = discriminate(discriminator, fake_image, step, alpha,
                                        std, args.n_gpu)
            fake_loss = nn.functional.softplus(-fake_predict).mean()

            g_optim.zero_grad()
            fake_loss.backward()
            g_optim.step()

        if iteration % n_show_loss == 0:
            g_losses.append(fake_loss.item())
            d_losses.append(d_loss)
            # print(fd_calculator.calculate_fd(fake_image))
            # fd.append(fd_calculator.calculate_fd(fake_image))

            wandb.log(
                {
                    "G Loss": g_losses[-1],
                    "D Loss": d_losses[-1],
                    # "Domain FD": fd[-1],
                    "Images Shown": used_sample
                },
                step=iteration)

        if iteration % n_save_im == 0:
            imsave(fake_image.data.cpu(), iteration)

        del fake_image, fake_loss

        if iteration % n_checkpoint == 0:
            # Save a full checkpoint every n_checkpoint iterations
            torch.save(
                {
                    'generator': generator.state_dict(),
                    'discriminator': discriminator.state_dict(),
                    'g_optim': g_optim.state_dict(),
                    'd_optim': d_optim.state_dict(),
                    'parameters':
                    (step, iteration, startpoint, used_sample, alpha),
                    'd_losses': d_losses,
                    'g_losses': g_losses,
                },
                f'{args.save_checkpoints + args.run_name}/trained-{iteration}.pth'
            )
            wandb.save(
                f'{args.save_checkpoints + args.run_name}/trained-{iteration}.pth'
            )
            print('Model successfully saved.')

        progress_bar.set_description((
            f'Resolution: {resolution[0]}*{resolution[1]}  D_Loss: {d_losses[-1]:.4f}  G_Loss: {g_losses[-1]:.4f}  Alpha: {alpha:.4f}'
        ))
Example #30
0
def train(
    run_name: str,
    # Data
    train_filepath: str = CSNJS_TRAIN_FILEPATH,
    eval_filepath: str = CSNJS_VALID_FILEPATH,
    spm_filepath: str = SPM_UNIGRAM_FILEPATH,
    program_mode: str = "identity",
    eval_program_mode: str = "identity",
    label_mode: str = "identifier",
    num_workers: int = 1,
    limit_dataset_size: int = -1,
    # Model
    model_type: str = "transformer",
    n_decoder_layers: int = 4,
    d_model: int = 512,
    resume_path: str = "",
    resume_encoder_name: str = "encoder_q",  # encoder_q, encoder_k, encoder
    resume_project: bool = False,
    # Optimization
    train_decoder_only: bool = False,
    num_epochs: int = 50,
    save_every: int = 2,
    batch_size: int = 256,
    lr: float = 8e-4,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.98,
    use_lr_warmup: bool = True,
    loss_type: str = "nll_token",  # nll_token or nll_sequence
    # Loss
    subword_regularization_alpha: float = 0,
    # Computational
    use_cuda: bool = True,
    auto_test: bool = True,
    seed: int = 0,
):
    """Train model"""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    run_dir = RUN_DIR / run_name
    run_dir.mkdir(exist_ok=True, parents=True)
    logger.add(str((run_dir / "train.log").resolve()))
    logger.info(f"Saving logs, model checkpoints to {run_dir}")
    config = locals()  # snapshot of the function arguments (plus locals defined so far) for logging
    logger.info(f"Config: {config}")
    wandb.init(name=run_name, config=config, job_type="training", project="identifier-prediction", entity="ml4code")

    if use_cuda:
        assert torch.cuda.is_available(), "CUDA not available. Check env configuration, or pass --use_cuda False"
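    # Program-level augmentations applied to the training set only
    # (the eval loader below passes augmentations=[]).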

    train_augmentations = [
        {"fn": "sample_lines", "line_length_pct": 0.5},
        {"fn": "insert_var_declaration", "prob": 0.5},
        {"fn": "rename_variable", "prob": 0.5},
    ]
    sp = spm.SentencePieceProcessor()
    sp.Load(spm_filepath)
    pad_id = sp.PieceToId("[PAD]")

    # Create training dataset and dataloader
    logger.info(f"Training data path {train_filepath}")
    train_dataset = get_csnjs_dataset(train_filepath, label_mode=label_mode, limit_size=limit_dataset_size)
    logger.info(f"Training dataset size: {len(train_dataset)}")
    train_loader = javascript_dataloader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        augmentations=train_augmentations,
        sp=sp,
        program_mode=program_mode,
        subword_regularization_alpha=subword_regularization_alpha,
    )

    # Create eval dataset and dataloader
    logger.info(f"Eval data path {eval_filepath}")
    eval_dataset = get_csnjs_dataset(eval_filepath, label_mode=label_mode, limit_size=limit_dataset_size)
    logger.info(f"Eval dataset size: {len(eval_dataset)}")
    eval_loader = javascript_dataloader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        augmentations=[],
        sp=sp,
        program_mode=eval_program_mode,
        subword_regularization_alpha=subword_regularization_alpha,
    )

    # Create model
    if model_type == "transformer":
        model = TransformerModel(n_tokens=sp.GetPieceSize(), pad_id=pad_id, n_decoder_layers=n_decoder_layers, d_model=d_model)
        logger.info(f"Created TransformerModel with {count_parameters(model)} params")
    elif model_type == "lstm":
        model = Seq2SeqLSTM(n_tokens=sp.GetPieceSize(), pad_id=pad_id, d_model=d_model)
        logger.info(f"Created Seq2SeqLSTM with {count_parameters(model)} params")
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    # Load checkpoint
    if resume_path:
        logger.info(f"Resuming training from checkpoint {resume_path}, resume_encoder_name={resume_encoder_name}")
        checkpoint = torch.load(resume_path)
        pretrained_state_dict = checkpoint["model_state_dict"]
        encoder_state_dict = {}
        assert resume_encoder_name in ["encoder_k", "encoder_q", "encoder"]
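        # Strip the "<resume_encoder_name>." prefix so the keys line up with model.encoder;
        # project_layer.0.* weights are only kept when resume_project is set.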

        for key, value in pretrained_state_dict.items():
            if key.startswith(resume_encoder_name + ".") and "project_layer" not in key:
                remapped_key = key[len(resume_encoder_name + ".") :]
                logger.debug(f"Remapping checkpoint key {key} to {remapped_key}. Value mean: {value.mean().item()}")
                encoder_state_dict[remapped_key] = value
            if key.startswith(resume_encoder_name + ".") and "project_layer.0." in key and resume_project:
                remapped_key = key[len(resume_encoder_name + ".") :]
                logger.debug(f"Remapping checkpoint project key {key} to {remapped_key}. Value mean: {value.mean().item()}")
                encoder_state_dict[remapped_key] = value
        model.encoder.load_state_dict(encoder_state_dict, strict=False)
        logger.info(f"Loaded state dict from {resume_path}")
        logger.info(f"Loaded keys: {encoder_state_dict.keys()}")

    # Set up optimizer
    model = nn.DataParallel(model)
    model = model.cuda() if use_cuda else model
    wandb.watch(model, log="all")
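    # With train_decoder_only, only the decoder's parameters are optimized;
    # the encoder is additionally kept in eval mode inside the training loop.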
    params = model.module.decoder.parameters() if train_decoder_only else model.parameters()
    optimizer = torch.optim.Adam(params, lr=lr, betas=(adam_beta1, adam_beta2), eps=1e-9)
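    # Presumably the HuggingFace-style schedule: linear warmup over the first 5000 steps,
    # then linear decay to 0 across len(train_loader) * num_epochs total steps.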
    if use_lr_warmup:
        scheduler = get_linear_schedule_with_warmup(optimizer, 5000, len(train_loader) * num_epochs)
    else:
        scheduler = LambdaLR(optimizer, lr_lambda=lambda x: 1.0)

    global_step = 0
    min_eval_loss = float("inf")
    for epoch in tqdm.trange(1, num_epochs + 1, desc="training", unit="epoch", leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        if train_decoder_only:
            model.module.encoder.eval()
            model.module.decoder.train()
        else:
            model.train()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for X, Y, X_lengths, Y_lengths in pbar:
            if use_cuda:
                X = X.cuda()
                Y = Y.cuda()
                X_lengths, Y_lengths = X_lengths.cuda(), Y_lengths.cuda()
            optimizer.zero_grad()
            # NOTE: X and Y are [B, max_seq_len] tensors (batch first)
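            # Teacher forcing: the decoder consumes Y[:, :-1] and is trained to predict
            # the shifted next-token targets Y[:, 1:].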
            logits = model(X, Y[:, :-1], X_lengths, Y_lengths)
            if loss_type == "nll_sequence":
                loss = F.cross_entropy(logits.transpose(1, 2), Y[:, 1:], ignore_index=pad_id, reduction='sum')
                loss = loss / X.size(0)  # Average over the number of sequences, not target lengths,
                                         # so the objective is negative log-likelihood per sequence.
            elif loss_type == "nll_token":
                loss = F.cross_entropy(logits.transpose(1, 2), Y[:, 1:], ignore_index=pad_id,)
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Log loss
            global_step += 1
            wandb.log(
                {"epoch": epoch, f"label-{label_mode}/train_loss": loss.item(), "lr": scheduler.get_last_lr()[0]}, step=global_step
            )
            pbar.set_description(f"epoch {epoch} loss {loss.item():.4f}")

        # Evaluate
        logger.info(f"Evaluating model after epoch {epoch} ({global_step} steps)...")
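        # Identifiers are short, so decoding is capped at 20 subword tokens;
        # other label modes get a longer budget (200).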
        max_decode_len = 20 if label_mode == "identifier" else 200
        eval_loss = _evaluate(model, eval_loader, sp, use_cuda=use_cuda, max_decode_len=max_decode_len, loss_type=loss_type)
        logger.info(f"Evaluation loss after epoch {epoch} ({global_step} steps): {eval_loss:.4f}")
        wandb.log({"epoch": epoch, f"label-{label_mode}/eval_loss": eval_loss}, step=global_step)

        # Save checkpoint
        if (save_every and epoch % save_every == 0) or eval_loss < min_eval_loss:
            checkpoint = {
                "model_state_dict": model.module.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch": epoch,
                "global_step": global_step,
                "config": config,
                "eval_loss": eval_loss,
            }
            if eval_loss < min_eval_loss:
                logger.info(f"New best evaluation loss: prev {min_eval_loss:.4f} > new {eval_loss:.4f}")
                min_eval_loss = eval_loss
                model_file = run_dir / "ckpt_best.pth"
            else:
                model_file = run_dir / f"ckpt_ep{epoch:04d}.pth"
            logger.info(f"Saving checkpoint to {model_file}...")
            torch.save(checkpoint, str(model_file.resolve()))
            wandb.save(str(model_file.resolve()))
            logger.info("Done.")

    if auto_test:
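        # Evaluate the best checkpoint on the CSNJS test split
        # (-1 presumably disables the dataset-size limit, mirroring limit_dataset_size above).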
        best_ckpt = run_dir / "ckpt_best.pth"
        test(
            str(best_ckpt.resolve()),
            CSNJS_TEST_FILEPATH,
            spm_filepath,
            program_mode,
            label_mode,
            num_workers,
            -1,
            n_decoder_layers=n_decoder_layers,
        )