Exemple #1
0
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator. When ``config.is_train`` it is a
          ``(train_loader, valid_loader)`` pair, otherwise a single
          test loader.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        # hard-coded for MNIST-style data: 10 classes, single channel
        self.num_classes = 10
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.no_tqdm = config.no_tqdm
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = 'ram_{}_{}x{}_{}'.format(config.num_glimpses,
                                                   config.patch_size,
                                                   config.patch_size,
                                                   config.glimpse_scale)

        # Encode the two boolean flags in the model name. Use truthiness
        # instead of comparing with '== True' (PEP 8).
        self.model_name += ('_uncertainty_1' if config.uncertainty
                            else '_uncertainty_0')
        self.model_name += ('_intrinsic_1' if config.intrinsic
                            else '_intrinsic_0')

        self.plot_dir = './plots/' + self.model_name + '/'
        # exist_ok avoids the check-then-create race of os.path.exists()
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(self.patch_size, self.num_patches,
                                        self.glimpse_scale, self.num_channels,
                                        self.loc_hidden, self.glimpse_hidden,
                                        self.std, self.hidden_size,
                                        self.num_classes, self.config)
        if self.use_gpu:
            self.model.cuda()

        self.dtypeFloat = (torch.cuda.FloatTensor
                           if self.use_gpu else torch.FloatTensor)
        self.dtypeLong = (torch.cuda.LongTensor
                          if self.use_gpu else torch.LongTensor)

        # generator expression avoids materializing an intermediate list
        print('[*] Number of model parameters: {:,}'.format(
            sum(p.data.nelement() for p in self.model.parameters())))

        # initialize optimizer and scheduler
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.config.init_lr,
        )

        def lambda_of_lr(epoch):
            # exponential decay: multiply the base lr by 0.95 each epoch
            return 0.95 ** epoch

        self.scheduler = LambdaLR(self.optimizer, lr_lambda=lambda_of_lr)
Exemple #2
0
def test_a2c(args=get_args()):
    """Train an A2C policy on the gym task named by ``args.task`` and report
    its final test-time reward/length.

    NOTE(review): the default ``args=get_args()`` is evaluated once at module
    import time, so repeated calls share a single parsed-args object —
    confirm this is intended.
    """
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low),
          np.max(env.action_space.high))
    # train_envs = gym.make(args.task)
    # Vectorized envs with observation normalization; the test envs reuse the
    # running statistics collected during training and keep them frozen.
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        norm_obs=True)
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        norm_obs=True,
        obs_rms=train_envs.obs_rms,
        update_obs_rms=False)

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model: Gaussian actor (unbounded mean) and critic, both tanh MLPs
    net_a = Net(args.state_shape,
                hidden_sizes=args.hidden_sizes,
                activation=nn.Tanh,
                device=args.device)
    actor = ActorProb(net_a,
                      args.action_shape,
                      max_action=args.max_action,
                      unbounded=True,
                      device=args.device).to(args.device)
    net_c = Net(args.state_shape,
                hidden_sizes=args.hidden_sizes,
                activation=nn.Tanh,
                device=args.device)
    critic = Critic(net_c, device=args.device).to(args.device)
    torch.nn.init.constant_(actor.sigma_param, -0.5)
    for m in list(actor.modules()) + list(critic.modules()):
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    # do last policy layer scaling, this will make initial actions have (close to)
    # 0 mean and std, and will help boost performances,
    # see https://arxiv.org/abs/2006.05990, Fig.24 for details
    for m in actor.mu.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.zeros_(m.bias)
            m.weight.data.copy_(0.01 * m.weight.data)

    # single optimizer over the union of actor and critic parameters
    optim = torch.optim.RMSprop(set(actor.parameters()).union(
        critic.parameters()),
                                lr=args.lr,
                                eps=1e-5,
                                alpha=0.99)

    lr_scheduler = None
    if args.lr_decay:
        # decay learning rate to 0 linearly
        max_update_num = np.ceil(
            args.step_per_epoch / args.step_per_collect) * args.epoch

        lr_scheduler = LambdaLR(
            optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

    def dist(*logits):
        # Diagonal Gaussian over the action dimensions.
        return Independent(Normal(*logits), 1)

    policy = A2CPolicy(actor,
                       critic,
                       optim,
                       dist,
                       discount_factor=args.gamma,
                       gae_lambda=args.gae_lambda,
                       max_grad_norm=args.max_grad_norm,
                       vf_coef=args.vf_coef,
                       ent_coef=args.ent_coef,
                       reward_normalization=args.rew_norm,
                       action_scaling=True,
                       action_bound_method=args.bound_action_method,
                       lr_scheduler=lr_scheduler,
                       action_space=env.action_space)

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    if args.training_num > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)
    train_collector = Collector(policy,
                                train_envs,
                                buffer,
                                exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_a2c'
    log_path = os.path.join(args.logdir, args.task, 'a2c', log_file)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = BasicLogger(writer, update_interval=100, train_interval=100)

    def save_fn(policy):
        # Checkpoint callback used by the trainer on new best results.
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    if not args.watch:
        # trainer
        result = onpolicy_trainer(policy,
                                  train_collector,
                                  test_collector,
                                  args.epoch,
                                  args.step_per_epoch,
                                  args.repeat_per_collect,
                                  args.test_num,
                                  args.batch_size,
                                  step_per_collect=args.step_per_collect,
                                  save_fn=save_fn,
                                  logger=logger,
                                  test_in_train=False)
        pprint.pprint(result)

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num,
                                    render=args.render)
    print(
        f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}'
    )
Exemple #3
0
def train(args: DictConfig) -> None:
    """Self-supervised SimCLR pre-training loop (requires a CUDA device).

    Builds a pair dataset (optionally with AugMix-style augmentations), a
    SimCLR model over a ResNet backbone, SGD with a cosine-annealed learning
    rate, and trains for ``args.epochs`` epochs, checkpointing every
    ``args.log_interval`` epochs.
    """
    assert torch.cuda.is_available()
    cudnn.benchmark = True

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(32),
        transforms.RandomHorizontalFlip(0.5),
        get_color_distortion(s=0.5)  # normal augmentation used in SimCLR
    ])

    # Prepare dataset
    data_dir = hydra.utils.to_absolute_path(
        args.data_dir)  # get absolute path of data dir
    if args.aug:
        aug_list = augmentations.augmentations_all if args.all_ops else augmentations.augmentations
        # NOTE(review): this branch does not append ToTensor(); presumably
        # AugPairDataset handles the tensor conversion — confirm.
        train_set = AugPairDataset(dataset=args.dataset,
                                   root=data_dir,
                                   transform=train_transform,
                                   download=True,
                                   aug_list=aug_list,
                                   width=args.mixture_width,
                                   depth=args.mixture_depth,
                                   aug_severity=args.aug_severity)
    else:
        train_transform = transforms.Compose(
            [train_transform, transforms.ToTensor()])
        train_set = PairDataset(
            dataset=args.dataset,
            root=data_dir,
            transform=train_transform,
            download=True,
        )

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              drop_last=True)

    # Prepare model
    assert args.backbone in ['resnet18', 'resnet34']
    base_encoder = eval(args.backbone)
    model = SimCLR(base_encoder, projection_dim=args.projection_dim).cuda()
    logger.info('Base model: {}'.format(args.backbone))
    logger.info('feature dim: {}, projection dim: {}'.format(
        model.feature_dim, args.projection_dim))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    # cosine annealing lr
    scheduler = LambdaLR(
        optimizer,
        lr_lambda=lambda step: get_lr(  # pylint: disable=g-long-lambda
            step,
            args.epochs * len(train_loader),
            args.learning_rate,  # lr_lambda computes multiplicative factor
            1e-3))

    # SimCLR training
    model.train()
    for epoch in range(1, args.epochs + 1):
        loss_meter = AverageMeter("Loss")
        train_bar = tqdm(train_loader)
        for x, y in train_bar:
            sizes = x.size()
            # Fold the two augmented views into the batch dimension so the
            # encoder processes 2*batch_size images at once.
            x = x.view(sizes[0] * 2, sizes[2], sizes[3],
                       sizes[4]).cuda(non_blocking=True)

            optimizer.zero_grad()
            feature, rep = model(x)
            loss = nt_xent(rep, args.temperature)
            loss.backward()
            optimizer.step()
            scheduler.step()  # per-step (not per-epoch) LR schedule

            loss_meter.update(loss.item(), x.size(0))
            train_bar.set_description(
                "Train epoch {}, SimCLR loss: {:.4f}".format(
                    epoch, loss_meter.avg))

        # save checkpoint every log_interval epochs
        if epoch >= args.log_interval and epoch % args.log_interval == 0:
            logger.info(
                "==> Save checkpoint. Train epoch {}, SimCLR loss: {:.4f}".
                format(epoch, loss_meter.avg))
            torch.save(
                model.state_dict(),
                'simclr_{}_epoch{}{}.pt'.format(args.backbone, epoch,
                                                '_aug' if args.aug else ''))
def main():
    """Fast CIFAR-10 training: half-precision ResNet-style net, SGD with a
    piecewise-linear per-step LR schedule, pre-normalized in-memory data.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=512, metavar="N")
    parser.add_argument('--test-batch-size', type=int, default=512, metavar='N')
    parser.add_argument("--epochs", type=int, default=24, metavar="N")
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR')
    parser.add_argument("--no-cuda", action="store_true", default=False)
    parser.add_argument("--seed", type=int, default=1, metavar="S")
    parser.add_argument("--save-model", action="store_true", default=False)
    parser.add_argument("--data", default="./data/")
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    if use_cuda:
        torch.backends.cudnn.benchmark = True

    # Per-channel CIFAR-10 statistics in 0-255 pixel scale.
    cifar_mean = np.array([125.31, 122.95, 113.87], dtype=np.float32)
    cifar_std = np.array([62.99, 62.09, 66.70], dtype=np.float32)

    # Normalize once up front and keep the whole dataset in memory (NCHW),
    # so the DataLoader only shuffles and batches.
    train_set = CIFAR10(root=args.data, train=True, download=True)
    data = train_set.data
    data = np.pad(data, [(0, 0), (4, 4), (4, 4), (0, 0)], mode="reflect")
    data = normalize(data, mean=cifar_mean, std=cifar_std)
    data = data.transpose([0, 3, 1, 2])
    train_set = list(zip(data, train_set.targets))
    train_transforms = [Crop(32, 32), FlipLR(), Cutout(8, 8)]
    train_set = RandomAugmentation(train_set, train_transforms)

    test_set = CIFAR10(root=args.data, train=False, download=False)
    data = test_set.data
    data = normalize(data, mean=cifar_mean, std=cifar_std)
    data = data.transpose([0, 3, 1, 2])
    test_set = list(zip(data, test_set.targets))

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        num_workers=0,
        pin_memory=True,
        shuffle=True,
        drop_last=True,
    )
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.test_batch_size,
        num_workers=0,
        pin_memory=True,
        shuffle=False,
        drop_last=False,
    )

    model = net().to(device).half()  # fp16 throughout for speed

    # Base lr is scaled down by batch size; weight decay scaled up to match.
    optimizer = SGD(
        model.parameters(),
        lr=args.lr / args.batch_size,
        momentum=0.9,
        weight_decay=5e-4 * args.batch_size,
        nesterov=True,
    )

    def piecewise_linear(step):
        # LambdaLR multiplicative factor, computed per optimizer step:
        # linear warmup to 0.4 over 5 epochs, then linear decay to 0.
        epoch = (step + 1) / len(train_loader)
        lr = np.interp([epoch], [0, 5, args.epochs], [0, 0.4, 0])[0]
        return lr

    scheduler = LambdaLR(optimizer, piecewise_linear, last_epoch=-1)

    for epoch in range(1, args.epochs + 1):
        # Re-sample the random augmentations for this epoch.
        train_set.sample_transformations()
        train_summary = train(
            args, model, device, train_loader, optimizer, scheduler, epoch
        )
        test_summary = test(model, device, test_loader)
        print(
            f"{epoch:3d}. epoch:"
            f" train accuracy = {train_summary['accuracy']:.4f} ({train_summary['loss']:.2e})"
            f" test accuracy = {test_summary['accuracy']:.4f} ({test_summary['loss']:.2e})"
        )

    if args.save_model:
        torch.save(model.state_dict(), "model.pt")
Exemple #5
0
 def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
              weight_decay=0, amsgrad=False, alpha=0.0):
     """Adam optimizer with an attached polynomial-decay LR scheduler.

     The attached LambdaLR multiplies the base lr by ``1 / (t + 1) ** alpha``
     at scheduler step ``t``; with the default ``alpha=0.0`` the learning
     rate stays constant.
     """
     super(ScheduledAdam, self).__init__(params, lr, betas, eps,
                                         weight_decay, amsgrad)

     def inverse_poly(step):
         # Multiplicative LR factor: 1 / (step + 1) ** alpha.
         return 1 / (step + 1) ** alpha

     self.scheduler = LambdaLR(self, inverse_poly)
Exemple #6
0
    def get_scheduler(self, optimizer):
        """Build an exponential-decay LR scheduler from the config.

        The multiplicative factor at ``step`` is
        ``gamma ** (step / step_size)``, i.e. the learning rate smoothly
        decays by a factor of ``gamma`` every ``step_size`` steps.
        """
        opts = self.cfg.scheduler.options

        def decay_factor(step):
            # Smooth (non-staircase) exponential decay.
            return opts.gamma ** (step / opts.step_size)

        return LambdaLR(optimizer, lr_lambda=decay_factor)
Exemple #7
0
def get_optim_scheduler(optimizer):
    """Attach the module-level ``lr_lambda_fun`` schedule to ``optimizer``."""
    scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda_fun)
    return scheduler
Exemple #8
0
def main(args):
    """Pretrain a ResNet classifier on CIFAR-10 with SGD and exponential
    LR decay, evaluating on the test set after every epoch.

    Args:
        args: namespace providing at least ``epoch`` (number of epochs) and
            ``lr`` (initial learning rate), plus whatever ``parse_device``,
            ``log_dir`` and ``write_config`` expect.
    """
    device = parse_device(args)
    DATASET = CIFAR10
    args.jobname = "train_resnet_{}".format(DATASET.__name__)
    dirname = log_dir(args.jobname)  # experiment output directory
    net_pt = "{}/net.pt".format(dirname)

    net = ResNet(num_classes=10)
    args.net = "ResNet"
    write_config(dirname)
    temp_dict = read_config(dirname)  ## for testing read & write

    net.to(device)

    # Map each channel to roughly [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    train_set = DATASET(root='./data/cifar10',
                        train=True,
                        download=True,
                        transform=transform,
                        target_transform=None)
    test_set = DATASET(root='./data/cifar10',
                       train=False,
                       download=True,
                       transform=transform,
                       target_transform=None)

    num_workers = 4
    train_loader = DataLoader(train_set,
                              batch_size=100,
                              shuffle=True,
                              num_workers=num_workers,
                              drop_last=False)

    test_loader = DataLoader(test_set,
                             batch_size=500,
                             shuffle=False,
                             num_workers=num_workers,
                             drop_last=False)

    # Baseline accuracy before any training.
    test(test_loader, net, device=device, log_dir=dirname, debug=True)

    ### pretrain
    print("Pretraining...")
    num_epoch = args.epoch
    lr = args.lr
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)
    #optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    def exp_decay(epoch):
        # Multiply the base lr by 0.95 after every epoch.
        return 0.95 ** epoch

    scheduler = LambdaLR(optimizer, exp_decay)

    for _ in tqdm(range(num_epoch)):
        train(train_loader, net, optimizer, device=device, log_dir=dirname)
        scheduler.step()
        test(test_loader, net, device=device, log_dir=dirname, debug=True)

    torch.save(net.state_dict(), net_pt)
Exemple #9
0
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--bert_model",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )
    parser.add_argument(
        "--from_pretrained",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )
    parser.add_argument(
        "--output_dir",
        default="save",
        type=str,
        help="The output directory where the model checkpoints will be written.",
    )
    parser.add_argument(
        "--config_file",
        default="config/bert_base_6layer_6conect.json",
        type=str,
        help="The config file which specified the model details.",
    )
    parser.add_argument(
        "--num_train_epochs",
        default=20,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--train_iter_multiplier",
        default=1.0,
        type=float,
        help="multiplier for the multi-task training.",
    )
    parser.add_argument(
        "--train_iter_gap",
        default=4,
        type=int,
        help="forward every n iteration is the validation score is not improving over the last 3 epoch, -1 means will stop",
    )
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.",
    )
    parser.add_argument(
        "--no_cuda", action="store_true", help="Whether not to use CUDA when available"
    )
    parser.add_argument(
        "--do_lower_case",
        default=True,
        type=bool,
        help="Whether to lower case the input text. True for uncased models, False for cased models.",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local_rank for distributed training on gpus",
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="random seed for initialization"
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumualte before performing a backward/update pass.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit float precision instead of 32-bit",
    )
    parser.add_argument(
        "--loss_scale",
        type=float,
        default=0,
        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=16,
        help="Number of workers in the dataloader.",
    )
    parser.add_argument(
        "--save_name", default="", type=str, help="save name for training."
    )
    parser.add_argument(
        "--in_memory",
        default=False,
        type=bool,
        help="whether use chunck for parallel training.",
    )
    parser.add_argument(
        "--optim", default="AdamW", type=str, help="what to use for the optimization."
    )
    parser.add_argument(
        "--tasks", default="", type=str, help="1-2-3... training task separate by -"
    )
    parser.add_argument(
        "--freeze",
        default=-1,
        type=int,
        help="till which layer of textual stream of vilbert need to fixed.",
    )
    parser.add_argument(
        "--vision_scratch",
        action="store_true",
        help="whether pre-trained the image or not.",
    )
    parser.add_argument(
        "--evaluation_interval", default=1, type=int, help="evaluate very n epoch."
    )
    parser.add_argument(
        "--lr_scheduler",
        default="mannul",
        type=str,
        help="whether use learning rate scheduler.",
    )
    parser.add_argument(
        "--baseline", action="store_true", help="whether use single stream baseline."
    )
    parser.add_argument(
        "--resume_file", default="", type=str, help="Resume from checkpoint"
    )
    parser.add_argument(
        "--dynamic_attention",
        action="store_true",
        help="whether use dynamic attention.",
    )
    parser.add_argument(
        "--clean_train_sets",
        default=True,
        type=bool,
        help="whether clean train sets for multitask data.",
    )
    parser.add_argument(
        "--visual_target",
        default=0,
        type=int,
        help="which target to use for visual branch. \
        0: soft label, \
        1: regress the feature, \
        2: NCE loss.",
    )
    parser.add_argument(
        "--task_specific_tokens",
        action="store_true",
        help="whether to use task specific tokens for the multi-task learning.",
    )

    args = parser.parse_args()
    with open("vilbert_tasks.yml", "r") as f:
        task_cfg = edict(yaml.safe_load(f))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.baseline:
        from pytorch_transformers.modeling_bert import BertConfig
        from vilbert.basebert import BaseBertForVLTasks
    else:
        from vilbert.vilbert import BertConfig
        from vilbert.vilbert import VILBertForVLTasks

    task_names = []
    task_lr = []
    for i, task_id in enumerate(args.tasks.split("-")):
        task = "TASK" + task_id
        name = task_cfg[task]["name"]
        task_names.append(name)
        task_lr.append(task_cfg[task]["lr"])

    base_lr = min(task_lr)
    loss_scale = {}
    for i, task_id in enumerate(args.tasks.split("-")):
        task = "TASK" + task_id
        loss_scale[task] = task_lr[i] / base_lr

    if args.save_name:
        prefix = "-" + args.save_name
    else:
        prefix = ""
    timeStamp = (
        "-".join(task_names)
        + "_"
        + args.config_file.split("/")[1].split(".")[0]
        + prefix
    )
    savePath = os.path.join(args.output_dir, timeStamp)

    bert_weight_name = json.load(
        open("config/" + args.bert_model + "_weight_name.json", "r")
    )

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        )
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        torch.distributed.init_process_group(backend="nccl")

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
            device, n_gpu, bool(args.local_rank != -1), args.fp16
        )
    )

    default_gpu = False
    if dist.is_available() and args.local_rank != -1:
        rank = dist.get_rank()
        if rank == 0:
            default_gpu = True
    else:
        default_gpu = True

    if default_gpu:
        if not os.path.exists(savePath):
            os.makedirs(savePath)

    config = BertConfig.from_json_file(args.config_file)
    if default_gpu:
        # save all the hidden parameters.
        with open(os.path.join(savePath, "command.txt"), "w") as f:
            print(args, file=f)  # Python 3.x
            print("\n", file=f)
            print(config, file=f)

    task_batch_size, task_num_iters, task_ids, task_datasets_train, task_datasets_val, task_dataloader_train, task_dataloader_val = LoadDatasets(
        args, task_cfg, args.tasks.split("-")
    )

    logdir = os.path.join(savePath, "logs")
    tbLogger = utils.tbLogger(
        logdir,
        savePath,
        task_names,
        task_ids,
        task_num_iters,
        args.gradient_accumulation_steps,
    )

    if args.visual_target == 0:
        config.v_target_size = 1601
        config.visual_target = args.visual_target
    else:
        config.v_target_size = 2048
        config.visual_target = args.visual_target

    if args.task_specific_tokens:
        config.task_specific_tokens = True

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_ave_iter = {}
    task_stop_controller = {}
    for task_id, num_iter in task_num_iters.items():
        task_ave_iter[task_id] = int(
            task_cfg[task]["num_epoch"]
            * num_iter
            * args.train_iter_multiplier
            / args.num_train_epochs
        )
        task_stop_controller[task_id] = utils.MultiTaskStopOnPlateau(
            mode="max",
            patience=1,
            continue_threshold=0.005,
            cooldown=1,
            threshold=0.001,
        )

    task_ave_iter_list = sorted(task_ave_iter.values())
    median_num_iter = task_ave_iter_list[-1]
    num_train_optimization_steps = (
        median_num_iter * args.num_train_epochs // args.gradient_accumulation_steps
    )
    num_labels = max([dataset.num_labels for dataset in task_datasets_train.values()])

    if args.dynamic_attention:
        config.dynamic_attention = True
    if "roberta" in args.bert_model:
        config.model = "roberta"

    if args.baseline:
        model = BaseBertForVLTasks.from_pretrained(
            args.from_pretrained,
            config=config,
            num_labels=num_labels,
            default_gpu=default_gpu,
        )
    else:
        model = VILBertForVLTasks.from_pretrained(
            args.from_pretrained,
            config=config,
            num_labels=num_labels,
            default_gpu=default_gpu,
        )

    task_losses = LoadLosses(args, task_cfg, args.tasks.split("-"))

    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]

    if args.freeze != -1:
        bert_weight_name_filtered = []
        for name in bert_weight_name:
            if "embeddings" in name:
                bert_weight_name_filtered.append(name)
            elif "encoder" in name:
                layer_num = name.split(".")[2]
                if int(layer_num) <= args.freeze:
                    bert_weight_name_filtered.append(name)

        optimizer_grouped_parameters = []
        for key, value in dict(model.named_parameters()).items():
            if key[12:] in bert_weight_name_filtered:
                value.requires_grad = False

        if default_gpu:
            print("filtered weight")
            print(bert_weight_name_filtered)

    optimizer_grouped_parameters = []
    for key, value in dict(model.named_parameters()).items():
        if value.requires_grad:
            if "vil_" in key:
                lr = 1e-4
            else:
                if args.vision_scratch:
                    if key[12:] in bert_weight_name:
                        lr = base_lr
                    else:
                        lr = 1e-4
                else:
                    lr = base_lr
            if any(nd in key for nd in no_decay):
                optimizer_grouped_parameters += [
                    {"params": [value], "lr": lr, "weight_decay": 0.0}
                ]
            if not any(nd in key for nd in no_decay):
                optimizer_grouped_parameters += [
                    {"params": [value], "lr": lr, "weight_decay": 0.01}
                ]

    if default_gpu:
        print(len(list(model.named_parameters())), len(optimizer_grouped_parameters))

    if args.optim == "AdamW":
        optimizer = AdamW(optimizer_grouped_parameters, lr=base_lr, correct_bias=False)
    elif args.optim == "RAdam":
        optimizer = RAdam(optimizer_grouped_parameters, lr=base_lr)

    warmpu_steps = args.warmup_proportion * num_train_optimization_steps

    if args.lr_scheduler == "warmup_linear":
        warmup_scheduler = WarmupLinearSchedule(
            optimizer, warmup_steps=warmpu_steps, t_total=num_train_optimization_steps
        )
    else:
        warmup_scheduler = WarmupConstantSchedule(optimizer, warmup_steps=warmpu_steps)

    lr_reduce_list = np.array([5, 7])
    if args.lr_scheduler == "automatic":
        lr_scheduler = ReduceLROnPlateau(
            optimizer, mode="max", factor=0.2, patience=1, cooldown=1, threshold=0.001
        )
    elif args.lr_scheduler == "cosine":
        lr_scheduler = CosineAnnealingLR(
            optimizer, T_max=median_num_iter * args.num_train_epochs
        )
    elif args.lr_scheduler == "cosine_warm":
        lr_scheduler = CosineAnnealingWarmRestarts(
            optimizer, T_0=median_num_iter * args.num_train_epochs
        )
    elif args.lr_scheduler == "mannul":

        def lr_lambda_fun(epoch):
            return pow(0.2, np.sum(lr_reduce_list <= epoch))

        lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda_fun)

    startIterID = 0
    global_step = 0
    start_epoch = 0

    if args.resume_file != "" and os.path.exists(args.resume_file):
        checkpoint = torch.load(args.resume_file, map_location="cpu")
        new_dict = {}
        for attr in checkpoint["model_state_dict"]:
            if attr.startswith("module."):
                new_dict[attr.replace("module.", "", 1)] = checkpoint[
                    "model_state_dict"
                ][attr]
            else:
                new_dict[attr] = checkpoint["model_state_dict"][attr]
        model.load_state_dict(new_dict)
        warmup_scheduler.load_state_dict(checkpoint["warmup_scheduler_state_dict"])
        # lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        global_step = checkpoint["global_step"]
        start_epoch = int(checkpoint["epoch_id"]) + 1
        task_stop_controller = checkpoint["task_stop_controller"]
        tbLogger = checkpoint["tb_logger"]
        del checkpoint

    model.to(device)

    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.cuda()

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model, delay_allreduce=True)

    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if default_gpu:
        print("***** Running training *****")
        print("  Num Iters: ", task_num_iters)
        print("  Batch size: ", task_batch_size)
        print("  Num steps: %d" % num_train_optimization_steps)

    task_iter_train = {name: None for name in task_ids}
    task_count = {name: 0 for name in task_ids}
    for epochId in tqdm(range(start_epoch, args.num_train_epochs), desc="Epoch"):
        model.train()
        for step in range(median_num_iter):
            iterId = startIterID + step + (epochId * median_num_iter)
            first_task = True
            for task_id in task_ids:
                is_forward = False
                if (not task_stop_controller[task_id].in_stop) or (
                    iterId % args.train_iter_gap == 0
                ):
                    is_forward = True

                if is_forward:
                    loss, score = ForwardModelsTrain(
                        args,
                        task_cfg,
                        device,
                        task_id,
                        task_count,
                        task_iter_train,
                        task_dataloader_train,
                        model,
                        task_losses,
                    )

                    loss = loss * loss_scale[task_id]
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    loss.backward()
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        if args.fp16:
                            lr_this_step = args.learning_rate * warmup_linear(
                                global_step / num_train_optimization_steps,
                                args.warmup_proportion,
                            )
                            for param_group in optimizer.param_groups:
                                param_group["lr"] = lr_this_step

                        if first_task and (
                            global_step < warmpu_steps
                            or args.lr_scheduler == "warmup_linear"
                        ):
                            warmup_scheduler.step()

                        optimizer.step()
                        model.zero_grad()
                        if first_task:
                            global_step += 1
                            first_task = False

                        if default_gpu:
                            tbLogger.step_train(
                                epochId,
                                iterId,
                                float(loss),
                                float(score),
                                optimizer.param_groups[0]["lr"],
                                task_id,
                                "train",
                            )

            if "cosine" in args.lr_scheduler and global_step > warmpu_steps:
                lr_scheduler.step()

            if (
                step % (20 * args.gradient_accumulation_steps) == 0
                and step != 0
                and default_gpu
            ):
                tbLogger.showLossTrain()

            # decided whether to evaluate on each tasks.
            for task_id in task_ids:
                if (iterId != 0 and iterId % task_num_iters[task_id] == 0) or (
                    epochId == args.num_train_epochs - 1 and step == median_num_iter - 1
                ):
                    evaluate(
                        args,
                        task_dataloader_val,
                        task_stop_controller,
                        task_cfg,
                        device,
                        task_id,
                        model,
                        task_losses,
                        epochId,
                        default_gpu,
                        tbLogger,
                    )

        if args.lr_scheduler == "automatic":
            lr_scheduler.step(sum(val_scores.values()))
            logger.info("best average score is %3f" % lr_scheduler.best)
        elif args.lr_scheduler == "mannul":
            lr_scheduler.step()

        if epochId in lr_reduce_list:
            for task_id in task_ids:
                # reset the task_stop_controller once the lr drop
                task_stop_controller[task_id]._reset()

        if default_gpu:
            # Save a trained model
            logger.info("** ** * Saving fine - tuned model ** ** * ")
            model_to_save = (
                model.module if hasattr(model, "module") else model
            )  # Only save the model it-self
            output_model_file = os.path.join(
                savePath, "pytorch_model_" + str(epochId) + ".bin"
            )
            output_checkpoint = os.path.join(savePath, "pytorch_ckpt_latest.tar")
            torch.save(model_to_save.state_dict(), output_model_file)
            torch.save(
                {
                    "model_state_dict": model_to_save.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "warmup_scheduler_state_dict": warmup_scheduler.state_dict(),
                    # 'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                    "global_step": global_step,
                    "epoch_id": epochId,
                    "task_stop_controller": task_stop_controller,
                    "tb_logger": tbLogger,
                },
                output_checkpoint,
            )
    tbLogger.txt_close()
Exemple #10
0
def train_mlt_single(args):
    """Run an iterative train-prune loop for a single task.

    Loads the task's train/dev/test splits, builds the model, then for
    ``args.pruning_iter + 1`` rounds: rebuilds the optimizer, trains with
    fastNLP's ``Trainer`` (the ``Pruning`` object participates as a
    callback), evaluates on the test set, saves the current weights, and
    prunes the model one step further.

    Args:
        args: parsed command-line namespace. Fields read here include
            data_path, task_id, debug, tb_path, init_weights, need_cut,
            final_rate, pruning_iter, init_masks, save_path, optim,
            batch_size, and epochs.
    """
    global logger
    logger.info(args)
    # Resolve the requested task and its three data splits.
    task_lst, vocabs = utils.get_data(args.data_path)
    task_db = task_lst[args.task_id]
    train_data = task_db.train_set
    dev_data = task_db.dev_set
    test_data = task_db.test_set
    task_name = task_db.task_name

    # Debug mode: shrink the data and the schedule for a quick smoke run.
    if args.debug:
        train_data = train_data[:200]
        dev_data = dev_data[:200]
        test_data = test_data[:200]
        args.epochs = 3
        args.pruning_iter = 3

    summary_writer = SummaryWriter(
        log_dir=os.path.join(args.tb_path, "global/%s" % task_name)
    )

    logger.info("task name: {}, task id: {}".format(task_db.task_name, task_db.task_id))
    logger.info(
        "train len {}, dev len {}, test len {}".format(
            len(train_data), len(dev_data), len(test_data)
        )
    )

    # init model
    model = get_model(args, task_lst, vocabs)

    logger.info("model: \n{}".format(model))
    if args.init_weights is not None:
        utils.load_model(model, args.init_weights)

    # Pick metrics by task type: accuracy for classification-style tasks,
    # span-F1 (BIOES for NER, BIO otherwise) for tagging tasks.
    if utils.need_acc(task_name):
        metrics = [AccuracyMetric(target="y"), MetricInForward(val_name="loss")]
        metric_key = "acc"

    else:
        metrics = [
            YangJieSpanMetric(
                tag_vocab=vocabs[task_name],
                pred="pred",
                target="y",
                seq_len="seq_len",
                encoding_type="bioes" if task_name == "ner" else "bio",
            ),
            MetricInForward(val_name="loss"),
        ]
        metric_key = "f"
    logger.info(metrics)

    # Collect the parameter names eligible for pruning: any trainable,
    # non-bias parameter whose name contains one of the --need-cut tokens.
    need_cut_names = list(set([s.strip() for s in args.need_cut.split(",")]))
    prune_names = []
    for name, p in model.named_parameters():
        if not p.requires_grad or "bias" in name:
            continue
        for n in need_cut_names:
            if n in name:
                prune_names.append(name)
                break

    # get Pruning class
    pruner = Pruning(
        model, prune_names, final_rate=args.final_rate, pruning_iter=args.pruning_iter
    )
    if args.init_masks is not None:
        # Resume from previously computed masks and apply them immediately.
        pruner.load(args.init_masks)
        pruner.apply_mask(pruner.remain_mask, pruner._model)
    # save checkpoint
    os.makedirs(args.save_path, exist_ok=True)

    logger.info('Saving init-weights to {}'.format(args.save_path))
    torch.save(
        model.cpu().state_dict(), os.path.join(args.save_path, "init_weights.th")
    )
    torch.save(args, os.path.join(args.save_path, "args.th"))
    # start training and pruning
    summary_writer.add_scalar("remain_rate", 100.0, 0)
    summary_writer.add_scalar("cutoff", 0.0, 0)

    # Baseline evaluation of the loaded init weights before any training.
    if args.init_weights is not None:
        init_tester = Tester(
            test_data,
            model,
            metrics=metrics,
            batch_size=args.batch_size,
            num_workers=4,
            device="cuda",
            use_tqdm=False,
        )
        res = init_tester.test()
        logger.info("No init testing, Result: {}".format(res))
        del res, init_tester

    for prune_step in range(pruner.pruning_iter + 1):
        # reset optimizer every time
        optim_params = [p for p in model.parameters() if p.requires_grad]
        # utils.get_logger(__name__).debug(optim_params)
        utils.get_logger(__name__).debug(len(optim_params))
        optimizer = get_optim(args.optim, optim_params)
        # optimizer = TriOptim(optimizer, args.n_filters, args.warmup, args.decay)
        # NOTE(review): the rate-based factor is computed and then immediately
        # overridden to 1.0, so the LR rescaling below is a no-op — looks like
        # a leftover debug override; confirm intent before relying on it.
        factor = pruner.cur_rate / 100.0
        factor = 1.0
        # print(factor, pruner.cur_rate)
        for pg in optimizer.param_groups:
            pg["lr"] = factor * pg["lr"]
        utils.get_logger(__name__).info(optimizer)

        # One full training run for this pruning round; the pruner callback
        # keeps the masks applied during updates.
        trainer = Trainer(
            train_data,
            model,
            loss=LossInForward(),
            optimizer=optimizer,
            metric_key=metric_key,
            metrics=metrics,
            print_every=200,
            batch_size=args.batch_size,
            num_workers=4,
            n_epochs=args.epochs,
            dev_data=dev_data,
            save_path=None,
            sampler=fastNLP.BucketSampler(batch_size=args.batch_size),
            callbacks=[
                pruner,
                # LRStep(lstm.WarmupLinearSchedule(optimizer, args.warmup, int(len(train_data)/args.batch_size*args.epochs)))
                GradientClipCallback(clip_type="norm", clip_value=5),
                LRScheduler(
                    lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.05 * ep))
                ),
                LogCallback(path=os.path.join(args.tb_path, "No", str(prune_step))),
            ],
            use_tqdm=False,
            device="cuda",
            check_code_level=-1,
        )
        res = trainer.train()
        logger.info("No #{} training, Result: {}".format(pruner.prune_times, res))
        name, val = get_metric(res)
        # NOTE(review): tag is spelled "prunning_dev_acc" (vs "pruning_test_acc"
        # below); left unchanged because existing dashboards may key on it.
        summary_writer.add_scalar("prunning_dev_acc", val, prune_step)
        tester = Tester(
            test_data,
            model,
            metrics=metrics,
            batch_size=args.batch_size,
            num_workers=4,
            device="cuda",
            use_tqdm=False,
        )
        res = tester.test()
        logger.info("No #{} testing, Result: {}".format(pruner.prune_times, res))
        name, val = get_metric(res)
        summary_writer.add_scalar("pruning_test_acc", val, prune_step)

        # prune and save
        torch.save(
            model.state_dict(),
            os.path.join(
                args.save_path,
                "best_{}_{}.th".format(pruner.prune_times, pruner.cur_rate),
            ),
        )
        pruner.pruning_model()
        summary_writer.add_scalar("remain_rate", pruner.cur_rate, prune_step + 1)
        summary_writer.add_scalar("cutoff", pruner.last_cutoff, prune_step + 1)

        pruner.save(
            os.path.join(
                args.save_path, "{}_{}.th".format(pruner.prune_times, pruner.cur_rate)
            )
        )
Exemple #11
0
def train(ibis_data,
          input_shape=(264, 264, 264),
          model_prefix=None,
          check_point=True,
          save_final=True,
          only_aa=False,
          only_atom=False,
          non_geom_features=False,
          use_deepsite_features=False,
          expand_atom=False,
          num_workers=None,
          num_epochs=30,
          batch_size=20,
          shuffle=True,
          use_gpu=True,
          initial_learning_rate=0.0001,
          learning_rate_drop=0.5,
          learning_rate_epochs=10,
          lr_decay=4e-2,
          data_split=0.8,
          course_grained=False,
          no_batch_norm=False,
          use_resnet_unet=False,
          unclustered=False,
          undersample=False,
          oversample=False,
          nFeatures=None,
          allow_feature_combos=False,
          bs_feature=None,
          bs_feature2=None,
          bs_features=None,
          stripes=False,
          data_parallel=False,
          dropout_depth=False,
          dropout_width=False,
          dropout_p=0.5,
          wide_model=False,
          cellular_organisms=False,
          autoencoder=False,
          checkpoint_callback=None):
    """Train a sparse 3D segmentation network (Python 2 code).

    Builds train/val datasets either from a synthetic "spheres" source or
    from an IBIS data file, constructs a ResNet-UNet or UNet3D model,
    and runs a train/val loop with checkpointing and MeterLogger-based
    metric logging.

    ibis_data: either the literal string "spheres" or a path to an IBIS
        dataset file; anything else raises RuntimeError.
    checkpoint_callback: optional callable invoked after each epoch (and
        once at the end) with (epoch, statsfile, graphs, epoch_file,
        model_file).
    Returns the trained model.

    NOTE(review): save_final, expand_atom, learning_rate_drop,
    learning_rate_epochs and unclustered are accepted but never used in
    this body — confirm whether they are consumed elsewhere or dead.
    """
    # Default model prefix is timestamped so runs don't clobber each other.
    if model_prefix is None:
        model_prefix = "./molmimic_model_{}".format(
            datetime.now().strftime('%Y-%m-%d_%H:%M:%S'))

    if num_workers is None:
        num_workers = multiprocessing.cpu_count() - 1

    since = time.time()

    # ---- dataset selection -------------------------------------------------
    if ibis_data == "spheres":
        # Synthetic sphere dataset; class count derives from bs_features.
        from torch_loader import SphereDataset
        nFeatures = nFeatures or 3
        datasets = SphereDataset.get_training_and_validation(
            input_shape,
            cnt=1,
            n_samples=1000,
            nFeatures=nFeatures,
            allow_feature_combos=allow_feature_combos,
            bs_feature=bs_feature,
            bs_feature2=bs_feature2,
            bs_features=bs_features,
            stripes=stripes,
            data_split=0.99)
        validation_batch_size = 1
        if bs_features is not None:
            nClasses = len(bs_features) + 1
        else:
            nClasses = 2
    elif os.path.isfile(ibis_data):
        dataset = IBISDataset
        print allow_feature_combos, nFeatures
        # random_features encodes the (count, combos, feature, feature2)
        # request; None means "use the dataset's own features".
        if allow_feature_combos and nFeatures is not None:
            random_features = (nFeatures, allow_feature_combos, bs_feature,
                               bs_feature2)
        elif not allow_feature_combos and nFeatures is not None:
            random_features = (nFeatures, False, bs_feature, bs_feature2)
        elif allow_feature_combos and nFeatures is None:
            random_features = None
            print "ignoring --allow-feature-combos"
        else:
            random_features = None

        datasets = dataset.get_training_and_validation(
            ibis_data,
            input_shape=input_shape,
            only_aa=only_aa,
            only_atom=only_atom,
            non_geom_features=non_geom_features,
            use_deepsite_features=use_deepsite_features,
            data_split=data_split,
            course_grained=course_grained,
            oversample=oversample,
            undersample=undersample,
            cellular_organisms=cellular_organisms,
            random_features=random_features)
        nFeatures = datasets["train"].get_number_of_features()
        # Autoencoder mode reconstructs the input features themselves.
        nClasses = 2 if not autoencoder else nFeatures

        validation_batch_size = batch_size
    else:
        raise RuntimeError("Invalid training data")

    # NOTE(review): the halving below is immediately overridden by the
    # hard-coded "num_workers = 6" — looks like a leftover experiment.
    if num_workers % 2 == 0:
        num_workers -= 1
    num_workers /= 2
    num_workers = 6

    dataloaders = {name:dataset.get_data_loader(
        batch_size if dataset.train else validation_batch_size,
        shuffle,
        num_workers) \
        for name, dataset in datasets.iteritems()}

    dtype = 'torch.cuda.FloatTensor' if torch.cuda.is_available(
    ) else 'torch.FloatTensor'

    # ---- model construction ------------------------------------------------
    if use_resnet_unet:
        model = ResNetUNet(nFeatures,
                           nClasses,
                           dropout_depth=dropout_depth,
                           dropout_width=dropout_width,
                           dropout_p=dropout_p,
                           wide_model=wide_model)
    else:
        model = UNet3D(nFeatures, nClasses, batchnorm=not no_batch_norm)

    if data_parallel:
        model = torch.nn.DataParallel(model)

    model.type(dtype)

    optimizer = SGD(model.parameters(),
                    lr=initial_learning_rate,
                    momentum=0.999,
                    weight_decay=1e-4,
                    nesterov=True)

    # Exponential LR decay: multiplier exp((1 - epoch) * lr_decay).
    scheduler = LambdaLR(optimizer, lambda epoch: math.exp(
        (1 - epoch) * lr_decay))

    # ---- checkpoint resume -------------------------------------------------
    check_point_model_file = "{}_checkpoint_model.pth".format(model_prefix)
    check_point_epoch_file = "{}_checkpoint_epoch.pth".format(model_prefix)
    if check_point and os.path.isfile(
            check_point_model_file) and os.path.isfile(check_point_epoch_file):
        start_epoch = torch.load(check_point_epoch_file)
        print "Restarting at epoch {} from {}".format(start_epoch + 1,
                                                      check_point_model_file)
        model.load_state_dict(torch.load(check_point_model_file))
    else:
        start_epoch = 0

    inputSpatialSize = torch.LongTensor(input_shape)

    # Render the autograd graph only once, on the first sparse batch.
    draw_graph = True

    mlog = MeterLogger(nclass=nClasses,
                       title="Sparse 3D UNet",
                       server="cn4216")

    #Start clean
    # Best-effort sweep deleting stray tensors before training starts.
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data')
                                        and torch.is_tensor(obj.data)):
                #print type(obj), obj.size()
                del obj
        except (SystemExit, KeyboardInterrupt):
            raise
        except Exception as e:
            pass

    # ---- epoch loop --------------------------------------------------------
    for epoch in xrange(start_epoch, num_epochs):
        print "Epoch {}/{}".format(epoch, num_epochs - 1)
        print "-" * 10

        mlog.timer.reset()
        for phase in ['train', 'val']:
            datasets[phase].epoch = epoch
            num_batches = int(
                np.ceil(
                    len(datasets[phase]) /
                    float(batch_size if phase ==
                          "train" else validation_batch_size)))

            if phase == 'train':
                scheduler.step()
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            # Iterate over data.
            bar = tqdm(enumerate(dataloaders[phase]),
                       total=num_batches,
                       unit="batch",
                       desc="Loading data",
                       leave=True)
            for data_iter_num, data in bar:
                datasets[phase].batch = data_iter_num
                batch_weight = data.get("weight", None)

                if batch_weight is not None:
                    batch_weight = torch.from_numpy(batch_weight).float()
                    if use_gpu:
                        batch_weight = batch_weight.cuda()

                sample_weights = data.get("sample_weights", None)

                if sample_weights is not None:
                    sample_weights = torch.from_numpy(sample_weights).float()
                    if use_gpu:
                        sample_weights = sample_weights.cuda()

                # Three input formats: a prebuilt sparse InputBatch, raw
                # per-sample lists/tuples to assemble into one, or a dense
                # FloatTensor converted to sparse.
                if data["data"].__class__.__name__ == "InputBatch":
                    sparse_input = True
                    inputs = data["data"]
                    labels = data["truth"]
                    if use_gpu:
                        inputs = inputs.cuda().to_variable(requires_grad=True)
                        labels = labels.cuda().to_variable()
                    else:
                        inputs = inputs.to_variable(requires_grad=True)
                        labels = labels.to_variable()

                elif isinstance(data["data"], (list, tuple)):
                    sparse_input = True
                    inputs = scn.InputBatch(3, inputSpatialSize)
                    labels = scn.InputBatch(3, inputSpatialSize)

                    if isinstance(data["data"][0], np.ndarray):
                        long_tensor = lambda arr: torch.from_numpy(arr).long()
                        float_tensor = lambda arr: torch.from_numpy(arr).float(
                        )
                    elif isinstance(data["data"][0], (list, tuple)):
                        long_tensor = lambda arr: torch.LongTensor(arr)
                        float_tensor = lambda arr: torch.FloatTensor(arr)
                    else:
                        raise RuntimeError("invalid datatype")

                    for sample, (indices, features, truth, id) in enumerate(
                            izip(data["indices"], data["data"], data["truth"],
                                 data["id"])):
                        inputs.addSample()
                        labels.addSample()

                        try:
                            indices = long_tensor(indices)
                            features = float_tensor(features)
                            truth = float_tensor(truth)
                        except RuntimeError as e:
                            print e
                            continue

                        try:
                            inputs.setLocations(
                                indices, features,
                                0)  #Use 1 to remove duplicate coords?
                            labels.setLocations(indices, truth, 0)
                        except AssertionError:
                            # Record problem structures rather than aborting.
                            print "Error with PDB:", id
                            with open("bad_pdbs.txt", "a") as f:
                                print >> f, id

                    del data
                    del indices
                    del truth

                    inputs.precomputeMetadata(1)

                    if use_gpu:
                        inputs = inputs.cuda()
                        labels = labels.cuda()

                    inputs = inputs.to_variable(requires_grad=True)
                    labels = labels.to_variable()

                elif isinstance(data["data"], torch.FloatTensor):
                    #Input is dense
                    print "Input is Dense"
                    sparse_input = False
                    # NOTE(review): `inputs`/`labels` are read here before
                    # they are assigned in this branch — this path would raise
                    # (or reuse the previous batch's tensors); likely a bug.
                    if use_gpu:
                        inputs = inputs.cuda()
                        labels = labels.cuda()
                    inputs = Variable(data["data"], requires_grad=True)
                    inputs = scn.DenseToSparse(3)(inputs)
                    try:
                        inputs = inputs.cuda().to_variable(requires_grad=True)
                    except:
                        pass
                    labels = Variable(data["truth"])

                else:
                    raise RuntimeError("Invalid data from dataset")

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                try:
                    outputs = model(inputs)
                except AssertionError:
                    print nFeatures, inputs
                    raise

                if sparse_input:
                    use_size_average = False
                    weight = sample_weights if use_size_average else batch_weight

                    loss_fn = torch.nn.CrossEntropyLoss(weight=weight)

                    loss = loss_fn(outputs, torch.max(labels.features, 1)[1])

                    if draw_graph:
                        var_dot = dot.make_dot(loss)
                        var_dot.render('SparseUnet3dCNN_graph.pdf')
                        draw_graph = False
                        del var_dot

                else:
                    outputs = scn.SparseToDense(3, 1)(outputs)
                    criterion = DiceLoss(size_average=False)
                    loss = criterion(outputs.cpu(), labels.cpu(
                    ))  #, inputs.getSpatialLocations(), scaling)
                    stats.update(outputs.data.cpu().view(-1),
                                 labels.data.cpu().view(-1), loss.data[0])

                mlog.update_loss(loss, meter='loss')
                mlog.update_meter(outputs,
                                  torch.max(labels.features, 1)[1],
                                  meters={'accuracy', 'map'})
                add_to_logger(mlog,
                              "Train" if phase == "train" else "Test",
                              epoch,
                              outputs,
                              labels.features,
                              batch_weight,
                              n_classes=nClasses)

                # backward + optimize only if in training phase
                if phase == 'train':
                    # Snapshot one parameter to verify the step changed it.
                    a = list(model.parameters())[0].clone().data
                    loss.backward()
                    optimizer.step()
                    b = list(model.parameters())[0].clone().data
                    if torch.equal(a, b): print "NOT UPDATED"
                    del a
                    del b

                bar.set_description("{}: [{}][{}/{}]".format(
                    phase, epoch, data_iter_num + 1, num_batches))
                bar.set_postfix(loss="{:.4f} ({:.4f})".format(
                    mlog.meter["loss"].val, mlog.meter["loss"].mean),
                                dice_class1="{:.4f} ({:.4f})".format(
                                    mlog.meter["dice_class1"].val,
                                    mlog.meter["dice_class1"].mean),
                                weight_dice="{:.4f} ({:.4f})".format(
                                    mlog.meter["weighted_dice_wavg"].val,
                                    mlog.meter["weighted_dice_wavg"].mean),
                                refresh=False)
                bar.refresh()

                # NOTE(review): `loss_fn` is only bound on the sparse path;
                # this del would raise NameError after a dense-only batch.
                del inputs
                del labels
                del outputs
                del loss
                del loss_fn
                del batch_weight
                del sample_weights

                #Delete all unused objects on the GPU
                for obj in gc.get_objects():
                    try:
                        if torch.is_tensor(obj) or (hasattr(obj, 'data') and
                                                    torch.is_tensor(obj.data)):
                            #print type(obj), obj.size()
                            del obj
                    except (SystemExit, KeyboardInterrupt):
                        raise
                    except Exception as e:
                        pass

            statsfile, graphs = graph_logger(
                mlog, "Train" if phase == "train" else "Test", epoch)
            mlog.reset_meter(epoch, "Train" if phase == "train" else "Test")

            if check_point:
                torch.save(epoch, check_point_epoch_file)
                torch.save(model.state_dict(), check_point_model_file)
                if callable(checkpoint_callback):
                    checkpoint_callback(epoch, statsfile, graphs,
                                        check_point_epoch_file,
                                        check_point_model_file)
            elif callable(checkpoint_callback):
                checkpoint_callback(epoch, statsfile, graphs, None, None)

    #stats.plot_final()

    statsfile, graphs = graph_logger(mlog,
                                     "Train" if phase == "train" else "Test",
                                     epoch,
                                     final=True)

    time_elapsed = time.time() - since
    print 'Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed / 60, time_elapsed % 60)

    torch.save(model.state_dict(), "{}.pth".format(model_prefix))

    if callable(checkpoint_callback):
        checkpoint_callback(epoch, statsfile, graphs, check_point_epoch_file,
                            check_point_model_file)

    return model
Exemple #12
0
    test_loader = DataLoader(test_data,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=16,
                             pin_memory=True)

    model = Net(num_class=2, pretrained_path=model_path).cuda()

    flops, params = profile(model,
                            inputs=(torch.randn(1, 3, 32, 32).cuda(), ),
                            verbose=False)
    flops, params = clever_format([flops, params])
    print('# Model Params: {} FLOPs: {}'.format(params, flops))
    optimizer = SGD(model.fc.parameters(), lr=30.0, momentum=0.9)
    lr_scheduler = LambdaLR(optimizer,
                            lr_lambda=lambda i: 0.5 *
                            (math.cos(i * math.pi / epochs) + 1))
    loss_criterion = nn.CrossEntropyLoss()
    results = {
        'train_loss': [],
        'train_acc@1': [],
        'train_acc@5': [],
        'test_loss': [],
        'test_acc@1': [],
        'test_acc@5': []
    }

    best_acc = 0.0
    if not os.path.exists('results'):
        os.mkdir('results')
    for epoch in range(1, epochs + 1):
Exemple #13
0
    def lr_lambda(current_step: int):
        """Piecewise LR multiplier: linear warmup, then linear decay, with
        a second linear ramp-up window starting at ``unlock_steps``."""
        # Initial warmup: rise linearly from 0 to 1 over num_warmup_steps.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Linear decay from 1 toward 0 over the remaining training steps,
        # clamped at 0.
        remaining = float(num_training_steps - current_step)
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        decayed = max(0.0, remaining / decay_span)
        # From unlock time and num_unlock_warmup_steps ahead, drop back to a
        # low LR and rise linearly up to the decayed value.
        time_after_unlock = current_step - unlock_steps
        if 0 <= time_after_unlock < num_unlock_warmup_steps:
            return decayed * time_after_unlock / max(1, num_unlock_warmup_steps)
        return decayed

    return LambdaLR(optimizer, lr_lambda, last_epoch)


if __name__ == "__main__":
    import torch
    from torch.optim import SGD
    import matplotlib.pyplot as plt

    steps = 20_000
    optimizer = SGD([torch.tensor(1)], lr=1)

    scheduler = get_lr_scheduler(optimizer, int(.06 * steps), steps,
                                 steps // 2)

    lrs = list()
    for _ in range(steps):
Exemple #14
0
def supervised(nb_epoch: int, optimizer: Optimizer, **kwargs) -> list:
    """Build the scheduler list for plain supervised training.

    Returns a one-element list holding a ``LambdaLR`` whose multiplier is
    produced by ``get_lr_lambda(nb_epoch)``. Extra keyword arguments are
    accepted for interface compatibility and ignored.
    """
    scheduler = LambdaLR(optimizer, get_lr_lambda(nb_epoch))
    return [scheduler]
Exemple #15
0
def setup_data_and_model(params, model):
    """Prepare all state for one experiment run.

    Sets up seeding, the experiment directory and logger, graph data
    loaders, optional checkpoint restoration, the optimizer with an
    iteration-based LR schedule, an optional molecule generator, and an
    optional TensorBoard writer.

    Args:
        params: experiment configuration object (paths, seeds, optimizer
            spec, checkpoint/generation flags, ...). ``val_seed`` may be
            filled in from ``seed`` as a side effect.
        model: graph model instance; may have weights restored and/or be
            moved to GPU in place.

    Returns:
        A tuple of all training state consumed by the caller's main loop
        (see the return statement for the exact order).
    """
    # Variables that may not otherwise be assigned
    writer = perturbation_loader = generator = training_smiles = None

    # setup random seeds
    if params.val_seed is None: params.val_seed = params.seed
    set_seed_if(params.seed)

    exp_path = os.path.join(params.dump_path, params.exp_name)
    # create exp path if it doesn't exist
    if not os.path.exists(exp_path):
        os.makedirs(exp_path)
    # create logger
    logger = create_logger(os.path.join(exp_path, 'train.log'), 0)
    pp = pprint.PrettyPrinter()
    logger.info("============ Initialized logger ============")
    logger.info("Random seed is {}".format(params.seed))
    if params.suppress_params is False:
        logger.info("\n".join("%s: %s" % (k, str(v))
                          for k, v in sorted(dict(vars(params)).items())))
        logger.info("Running command: %s" % 'python ' + ' '.join(sys.argv))
    logger.info("The experiment will be stored in %s\n" % exp_path)
    logger.info("")
    # load data
    train_data, val_dataset, train_loader, val_loader = load_graph_data(params)

    logger.info ('train_loader len is {}'.format(len(train_loader)))
    logger.info ('val_loader len is {}'.format(len(val_loader)))

    # Optionally initialize the binary-property embedding layer from
    # pretrained vectors (note the transpose before copying in).
    if params.num_binary_graph_properties > 0 and params.pretrained_property_embeddings_path:
        model.binary_graph_property_embedding_layer.weight.data = \
            torch.Tensor(np.load(params.pretrained_property_embeddings_path).T)
    # Decide which checkpoint (if any) to restore; 'latest' wins over 'best'.
    if params.load_latest is True:
        load_prefix = 'latest'
    elif params.load_best is True:
        load_prefix = 'best'
    else:
        load_prefix = None

    if load_prefix is not None:
        if params.local_cpu is True:
            model.load_state_dict(torch.load(os.path.join(exp_path, '{}_model'.format(load_prefix)), map_location='cpu'))
        else:
            model.load_state_dict(torch.load(os.path.join(exp_path, '{}_model'.format(load_prefix))))
    if params.local_cpu is False:
        model = model.cuda()
    # Generation setup: build a generator and keep the training split of the
    # SMILES file for later use against generated samples.
    if params.gen_num_samples > 0:
        generator = GraphGenerator(train_data, model, params.gen_random_init, params.gen_num_iters, params.gen_predict_deterministically, params.local_cpu)
        with open(params.smiles_path) as f:
            smiles = f.read().split('\n')
            training_smiles = smiles[:int(params.smiles_train_split * len(smiles))]
            del smiles
    opt = get_optimizer(model.parameters(), params.optimizer)
    if load_prefix is not None:
        opt.load_state_dict(torch.load(os.path.join(exp_path, '{}_opt_sd'.format(load_prefix))))

    # LR schedule: multiplier computed per iteration from the warm-up/decay
    # hyperparameters; `lr` is the base LR read from the (possibly restored)
    # optimizer.
    lr = opt.param_groups[0]['lr']
    lr_lambda = lambda iteration: lr_decay_multiplier(iteration, params.warm_up_iters, params.decay_start_iter,
                                                      params.lr_decay_amount, params.lr_decay_frac,
                                                      params.lr_decay_interval, params.min_lr, lr)
    scheduler = LambdaLR(opt, lr_lambda)
    index_method = get_index_method()

    best_loss = 9999  # sentinel "worst possible" initial best loss
    if params.tensorboard:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(exp_path)

    total_iter, grad_accum_iters = params.first_iter, 0

    return params, model, opt, scheduler, train_data, train_loader, val_dataset, val_loader, perturbation_loader,\
           generator, index_method, exp_path, training_smiles, pp, logger, writer, best_loss, total_iter,\
           grad_accum_iters
Exemple #16
0
def main(args: argparse.Namespace):
    """Adversarial domain-adaptation training for semantic segmentation.

    Trains a segmentation model on the labelled source dataset while a
    discriminator (via ``DomainAdversarialEntropyLoss``) aligns it with the
    unlabelled target dataset.  After each epoch the model is evaluated on
    the target validation split; the checkpoint with the best mean IoU over
    ``evaluate_classes`` is kept.  With ``args.phase == 'test'`` it only
    runs validation and returns.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(
        root=args.source_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size,
                                ratio=args.resize_ratio,
                                scale=(0.5, 1.)),
            T.ColorJitter(brightness=0.3, contrast=0.3),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)

    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(
        root=args.target_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size,
                                ratio=(2., 2.),
                                scale=(0.5, 1.)),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_target_dataset = target_dataset(
        root=args.target_root,
        split='val',
        transforms=T.Compose([
            T.Resize(image_size=args.test_input_size,
                     label_size=args.test_output_size),
            T.NormalizeAndTranspose(),
        ]),
    )
    val_target_loader = DataLoader(val_target_dataset,
                                   batch_size=1,
                                   shuffle=False,
                                   pin_memory=True)

    # Endless iterators so source and target batches can be drawn in
    # lockstep regardless of dataset length.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    num_classes = train_source_dataset.num_classes
    model = models.__dict__[args.arch](num_classes=num_classes).to(device)
    discriminator = Discriminator(num_classes=num_classes).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(model.get_parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    optimizer_d = Adam(discriminator.parameters(),
                       lr=args.lr_d,
                       betas=(0.9, 0.99))
    # Polynomial decay schedules.  NOTE(review): the model scheduler's
    # lambda multiplies by args.lr on top of the base lr already set in the
    # optimizer, while the discriminator's does not — confirm intended.
    lr_scheduler = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. - float(x) / args.epochs / args.iters_per_epoch)**(args.lr_power))
    lr_scheduler_d = LambdaLR(
        optimizer_d, lambda x:
        (1. - float(x) / args.epochs / args.iters_per_epoch)**(args.lr_power))

    # optionally resume from a checkpoint
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        discriminator.load_state_dict(checkpoint['discriminator'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        optimizer_d.load_state_dict(checkpoint['optimizer_d'])
        lr_scheduler_d.load_state_dict(checkpoint['lr_scheduler_d'])
        args.start_epoch = checkpoint['epoch'] + 1

    # define loss function (criterion)
    criterion = torch.nn.CrossEntropyLoss(
        ignore_index=args.ignore_label).to(device)
    dann = DomainAdversarialEntropyLoss(discriminator)
    # Upsample network outputs to the training crop size / test label size
    # (sizes are (W, H) tuples, reversed for nn.Upsample's (H, W) order).
    interp_train = nn.Upsample(size=args.train_size[::-1],
                               mode='bilinear',
                               align_corners=True)
    interp_val = nn.Upsample(size=args.test_output_size[::-1],
                             mode='bilinear',
                             align_corners=True)

    # define visualization function
    decode = train_source_dataset.decode_target

    def visualize(image, pred, label, prefix):
        """
        Save the input image, ground-truth label and argmax prediction
        as color images for debugging.

        Args:
            image (tensor): 3 x H x W
            pred (tensor): C x H x W
            label (tensor): H x W
            prefix: prefix of the saving image
        """
        image = image.detach().cpu().numpy()
        pred = pred.detach().max(dim=0)[1].cpu().numpy()
        label = label.cpu().numpy()
        for tensor, name in [
            (Image.fromarray(np.uint8(DeNormalizeAndTranspose()(image))),
             "image"), (decode(label), "label"), (decode(pred), "pred")
        ]:
            tensor.save(logger.get_image_path("{}_{}.png".format(prefix,
                                                                 name)))

    if args.phase == 'test':
        confmat = validate(val_target_loader, model, interp_val, criterion,
                           visualize, args)
        print(confmat)
        return

    # start training
    best_iou = 0.
    for epoch in range(args.start_epoch, args.epochs):
        logger.set_epoch(epoch)
        # NOTE(review): Scheduler.get_lr() is deprecated in newer PyTorch in
        # favour of get_last_lr() — confirm the targeted torch version.
        print(lr_scheduler.get_lr(), lr_scheduler_d.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, model, interp_train,
              criterion, dann, optimizer, lr_scheduler, optimizer_d,
              lr_scheduler_d, epoch, visualize if args.debug else None, args)

        # evaluate on validation set
        confmat = validate(val_target_loader, model, interp_val, criterion,
                           None, args)
        print(confmat.format(train_source_dataset.classes))
        acc_global, acc, iu = confmat.compute()

        # calculate the mean iou over partial classes
        indexes = [
            train_source_dataset.classes.index(name)
            for name in train_source_dataset.evaluate_classes
        ]
        iu = iu[indexes]
        mean_iou = iu.mean()

        # remember best acc@1 and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'discriminator': discriminator.state_dict(),
                'optimizer': optimizer.state_dict(),
                'optimizer_d': optimizer_d.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'lr_scheduler_d': lr_scheduler_d.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))
        if mean_iou > best_iou:
            shutil.copy(logger.get_checkpoint_path(epoch),
                        logger.get_checkpoint_path('best'))
        best_iou = max(best_iou, mean_iou)
        print("Target: {} Best: {}".format(mean_iou, best_iou))

    logger.close()
def train(mode='train',
          train_path='train.conllx',
          model='dozat',
          dataset='conllx',
          dev_path='dev.conllx',
          test_path='test.conllx',
          ud=True,
          output_dir='output',
          emb_dim=0,
          char_emb_dim=0,
          char_model=None,
          tagger=None,
          batch_size=5000,
          n_iters=10,
          dropout_p=0.33,
          num_layers=1,
          print_every=1,
          eval_every=100,
          bi=True,
          var_drop=False,
          upos_pred=False,
          lr=0.001,
          adam_beta1=0.9,
          adam_beta2=0.999,
          weight_decay=0.,
          plateau=False,
          resume=False,
          lr_decay=1.0,
          lr_decay_steps=5000,
          clip=5.,
          momentum=0,
          optimizer='adam',
          glove=True,
          seed=42,
          dim=0,
          window_size=0,
          num_filters=0,
          **kwargs):
    """Train a POS tagger on CoNLL-X or CoNLL-U data.

    Builds (or reloads) vocabularies, constructs the ``Tagger`` model,
    then runs an iteration-based training loop with periodic logging,
    evaluation via ``predict_and_save``/``get_pos_acc`` on dev and test,
    and checkpointing.  All hyperparameters are captured with ``locals()``
    and forwarded wholesale to the Tagger constructor.
    """

    device = torch.device(type='cuda') if use_cuda else torch.device(
        type='cpu')

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Snapshot every argument (plus `device`) for the Tagger(**cfg) call.
    cfg = locals().copy()

    torch.manual_seed(seed)
    np.random.seed(seed)

    # load data component
    if dataset == "conllx":
        dataset_obj = ConllXDataset
        fields = get_data_fields()
        _upos = None
        ud = False
    elif dataset == "conllu":
        dataset_obj = ConllUDataset
        fields = get_data_fields_conllu()
        _upos = fields['upos'][-1]
        ud = True
    else:
        raise NotImplementedError()

    _form = fields['form'][-1]
    _pos = fields['pos'][-1]
    _chars = fields['chars'][-1]

    train_dataset = dataset_obj(train_path, fields)
    dev_dataset = dataset_obj(dev_path, fields)
    test_dataset = dataset_obj(test_path, fields)

    logger.info("Loaded %d train examples" % len(train_dataset))
    logger.info("Number of train tokens: %d" % train_dataset.n_tokens)
    logger.info("Loaded %d dev examples" % len(dev_dataset))
    # NOTE(review): the two messages below say "train" but report dev/test
    # token counts.
    logger.info("Number of train tokens: %d" % dev_dataset.n_tokens)
    logger.info("Loaded %d test examples" % len(test_dataset))
    logger.info("Number of train tokens: %d" % test_dataset.n_tokens)

    form_vocab_path = os.path.join(output_dir, 'vocab.form.pth.tar')
    pos_vocab_path = os.path.join(output_dir, 'vocab.pos.pth.tar')
    char_vocab_path = os.path.join(output_dir, 'vocab.char.pth.tar')

    if not resume:
        # build vocabularies
        # words have a min frequency of 2 to be included; others become <unk>
        # words without a Glove vector are initialized ~ N(0, 0.5) mimicking Glove

        # Note: this requires the latest torchtext development version from Github.
        # - git clone https://github.com/pytorch/text.git torchtext
        # - cd torchtext
        # - python setup.py build
        # - python setup.py install

        def unk_init(x):
            # Zero-initialize out-of-vocabulary embeddings.
            # return 0.01 * torch.randn(x)
            return torch.zeros(x)

        if glove:
            logger.info("Using Glove vectors")
            glove_vectors = GloVe(name='6B', dim=100)
            _form.build_vocab(train_dataset,
                              min_freq=2,
                              unk_init=unk_init,
                              vectors=glove_vectors)
            n_unks = 0
            unk_set = set()
            # for now, set UNK words manually
            # (torchtext does not seem to support it yet)
            for i, token in enumerate(_form.vocab.itos):
                if token not in glove_vectors.stoi:
                    n_unks += 1
                    unk_set.add(token)
                    _form.vocab.vectors[i] = unk_init(emb_dim)
            # print(n_unks, unk_set)

        else:
            _form.build_vocab(train_dataset, min_freq=2)

        _pos.build_vocab(train_dataset)
        if ud:
            _upos.build_vocab(train_dataset)
        _chars.build_vocab(train_dataset)

        # save vocabularies
        torch.save(_form.vocab, form_vocab_path)
        torch.save(_pos.vocab, pos_vocab_path)
        torch.save(_chars.vocab, char_vocab_path)

    else:
        # load vocabularies
        _form.vocab = torch.load(form_vocab_path)
        _pos.vocab = torch.load(pos_vocab_path)
        _chars.vocab = torch.load(char_vocab_path)

    print("First 10 vocabulary entries, words: ",
          " ".join(_form.vocab.itos[:10]))
    print("First 10 vocabulary entries, pos tags: ",
          " ".join(_pos.vocab.itos[:10]))
    print("First 10 vocabulary entries, chars: ",
          " ".join(_chars.vocab.itos[:10]))
    if upos_pred:
        print("First 10 vocabulary entries, upos tags: ",
              " ".join(_upos.vocab.itos[:10]))

    n_words = len(_form.vocab)
    n_tags = len(_pos.vocab)
    if upos_pred:
        n_utags = len(_upos.vocab)
    else:
        n_utags = 0
    n_chars = len(_chars.vocab)

    def batch_size_fn(new, count, sofar):
        # Batch size measured in tokens: running total of sentence
        # lengths (+1 each), not in number of sentences.
        return len(new.form) + 1 + sofar

    # iterators
    train_iter = Iterator(train_dataset,
                          batch_size,
                          train=True,
                          sort_within_batch=True,
                          batch_size_fn=batch_size_fn,
                          device=device)
    dev_iter = Iterator(dev_dataset,
                        32,
                        train=False,
                        sort_within_batch=True,
                        device=device)
    test_iter = Iterator(test_dataset,
                         32,
                         train=False,
                         sort_within_batch=True,
                         device=device)

    # if n_iters or eval_every are negative, we set them to that many
    # number of epochs
    iters_per_epoch = (len(train_dataset) // batch_size) + 1
    if eval_every < 0:
        logger.info("Setting eval_every to %d epoch(s) = %d iters" %
                    (-1 * eval_every, -1 * eval_every * iters_per_epoch))
        # NOTE(review): unlike n_iters below, no sign flip here, so
        # eval_every stays negative; `iter_i % eval_every == 0` still fires
        # on multiples of its magnitude — confirm intended.
        eval_every = iters_per_epoch * eval_every

    if n_iters < 0:
        logger.info("Setting n_iters to %d epoch(s) = %d iters" %
                    (-1 * n_iters, -1 * n_iters * iters_per_epoch))
        n_iters = -1 * n_iters * iters_per_epoch

    # load up the model
    if upos_pred:
        upos_vocab = _upos.vocab
    else:
        upos_vocab = None
    model = Tagger(n_words=n_words,
                   n_tags=n_tags,
                   n_utags=n_utags,
                   n_chars=n_chars,
                   form_vocab=_form.vocab,
                   char_vocab=_chars.vocab,
                   pos_vocab=_pos.vocab,
                   upos_vocab=upos_vocab,
                   **cfg)

    # set word vectors
    if glove:
        # Rescale pretrained vectors to unit standard deviation before
        # copying them into the embedding layer.
        _form.vocab.vectors = _form.vocab.vectors / torch.std(
            _form.vocab.vectors)
        # print(torch.std(_form.vocab.vectors))
        model.encoder.embedding.weight.data.copy_(_form.vocab.vectors)
        model.encoder.embedding.weight.requires_grad = True

    model = model.cuda() if use_cuda else model

    start_iter = 1
    best_iter = 0
    best_pos_acc = -1.
    test_pos_acc = -1.

    # optimizer and learning rate scheduler
    trainable_parameters = [p for p in model.parameters() if p.requires_grad]
    if optimizer == 'sgd':
        optimizer = torch.optim.SGD(trainable_parameters,
                                    lr=lr,
                                    momentum=momentum)
    else:
        optimizer = torch.optim.Adam(trainable_parameters,
                                     lr=lr,
                                     betas=(adam_beta1, adam_beta2))

    # learning rate schedulers
    if not plateau:
        scheduler = LambdaLR(
            optimizer, lr_lambda=lambda t: lr_decay**(t / lr_decay_steps))
    else:
        scheduler = ReduceLROnPlateau(optimizer,
                                      mode='max',
                                      factor=0.75,
                                      patience=5,
                                      min_lr=1e-4)

    # load model and vocabularies if resuming
    if resume:
        if os.path.isfile(resume):
            print("=> loading checkpoint '{}'".format(resume))
            checkpoint = torch.load(resume)
            start_iter = checkpoint['iter_i']
            best_pos_acc = checkpoint['best_pos_acc']
            test_pos_acc = checkpoint['test_pos_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (iter {})".format(
                resume, checkpoint['iter_i']))
        else:
            print("=> no checkpoint found at '{}'".format(resume))

    print_parameters(model)

    # print some stuff just for fun
    logger.info("Most common words: %s" % _form.vocab.freqs.most_common(20))
    logger.info("Word vocab size: %s" % n_words)
    logger.info("Most common XPOS-tags: %s" % _pos.vocab.freqs.most_common())
    logger.info("POS vocab size: %s" % n_tags)
    # logger.info("Most common chars: %s" % _chars.nesting_field.vocab.freqs.most_common())
    logger.info("Chars vocab size: %s" % n_chars)

    print("First training example:")
    print_example(train_dataset[0])

    print("First dev example:")
    print_example(dev_dataset[0])

    print("First test example:")
    print_example(test_dataset[0])

    logger.info("Training starts..")
    upos_var, morph_var = None, None
    for iter_i in range(start_iter, n_iters + 1):

        # NOTE(review): both branches are identical, and `epoch_done` is
        # never used below (the epoch-based scheduler step is commented out).
        if not ud:
            epoch_done = (train_dataset.n_tokens // batch_size)
        else:
            epoch_done = (train_dataset.n_tokens // batch_size)

        # if not plateau and iter_i % epoch_done == 0:  # TODO: fix
        #   scheduler.step()
        # NOTE(review): scheduler.step() before optimizer.step() is the
        # pre-PyTorch-1.1 ordering; newer versions warn about this.
        scheduler.step()
        model.train()

        # NOTE(review): next(iter(...)) builds a fresh iterator every
        # update — verify that torchtext's Iterator resumes rather than
        # restarts its batch order here.
        batch = next(iter(train_iter))
        form_var, lengths = batch.form

        pos_var, pos_lengths = batch.pos
        if upos_pred:
            upos_var, _ = batch.upos
        else:
            upos_var = None

        char_var, sentence_lengths, word_lengths = batch.chars
        lengths = lengths.view(-1).tolist()

        result = model(form_var=form_var,
                       char_var=char_var,
                       pos_var=pos_var,
                       lengths=lengths,
                       word_lengths=word_lengths,
                       pos_lengths=pos_lengths)

        if upos_pred:
            targets = dict(pos=batch.pos, upos=batch.upos)
        else:
            targets = dict(pos=batch.pos, upos=None)

        all_losses = model.get_loss(scores=result, targets=targets)

        loss = all_losses['loss']

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()
        optimizer.zero_grad()

        if iter_i % print_every == 0:

            # get scores for this batch
            upos_predictions = []
            if model.tagger == "linear" or model.tagger == "mlp":
                if model.upos_pred:
                    upos_predictions = result['output']['upos'].max(2)[1]
                    pos_predictions = result['output']['xpos'].max(2)[1]
                else:
                    pos_predictions = result['output']['xpos'].max(2)[1]
            else:
                pos_predictions = result['sequence']

            predictions = dict(pos=pos_predictions, upos=upos_predictions)
            if model.upos_pred:
                targets = dict(pos=batch.pos, upos=batch.upos)
            else:
                targets = dict(pos=batch.pos, upos=None)

            pos_acc, upos_acc = model.get_accuracy(predictions=predictions,
                                                   targets=targets)

            # Current LR: LambdaLR exposes it via get_lr(); for plateau we
            # read it straight from the optimizer's param groups.
            if not plateau:
                lr = scheduler.get_lr()[0]
            else:
                lr = [group['lr'] for group in optimizer.param_groups][0]

            fmt = "Iter %08d loss %8.4f pos-acc %5.2f upos-acc %5.2f lr %.5f"

            logger.info(fmt % (iter_i, loss, pos_acc, upos_acc, lr))

        if iter_i % eval_every == 0:

            # parse dev set and save to file for official evaluation
            dev_out_path = 'dev.iter%08d.conll' % iter_i
            dev_out_path = os.path.join(output_dir, dev_out_path)
            predict_and_save(dataset=dev_dataset,
                             model=model,
                             dataset_path=dev_path,
                             out_path=dev_out_path)

            _dev_pos_acc, _dev_upos_acc = get_pos_acc(dev_path, dev_out_path,
                                                      ud)

            logger.info("Evaluation dev Iter %08d "
                        "pos-acc %5.2f upos-acc %5.2f" %
                        (iter_i, _dev_pos_acc, _dev_upos_acc))

            # parse test set and save to file for official evaluation
            test_out_path = 'test.iter%08d.conll' % iter_i
            test_out_path = os.path.join(output_dir, test_out_path)
            predict_and_save(dataset=test_dataset,
                             model=model,
                             dataset_path=test_path,
                             out_path=test_out_path)
            _test_pos_acc, _test_upos_acc = get_pos_acc(
                test_path, test_out_path, ud)

            logger.info("Evaluation test Iter %08d "
                        "pos-acc %5.2f upos-acc %5.2f" %
                        (iter_i, _test_pos_acc, _test_upos_acc))

            if plateau:
                scheduler.step(_dev_pos_acc)

            if _dev_pos_acc > best_pos_acc:
                best_iter = iter_i
                best_pos_acc = _dev_pos_acc
                test_pos_acc = _test_pos_acc
                is_best = True
            else:
                is_best = False

            # NOTE(review): `is_best` is computed above but a literal False
            # is passed here, so a best-model copy is never written — confirm.
            save_checkpoint(
                output_dir, {
                    'iter_i': iter_i,
                    'state_dict': model.state_dict(),
                    'best_iter': best_iter,
                    'test_pos_acc': test_pos_acc,
                    'optimizer': optimizer.state_dict(),
                }, False)

    logger.info("Done Training")
    logger.info(
        "Best model Iter %08d Dev POS-acc %12.4f Test POS-acc %12.4f " %
        (best_iter, best_pos_acc, test_pos_acc))
    def train(self, ckpt_path="", ckpt=-1, start_updates=0) -> None:
        r"""Main method for training PPO.

        Args:
            ckpt_path: checkpoint file to resume from (used when
                ``ckpt != -1``).
            ckpt: index of the checkpoint being resumed, or -1 to start a
                fresh run.
            start_updates: update counter to resume the loop from.

        Returns:
            None
        """
        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        ppo_cfg = self.config.RL.PPO
        task_cfg = self.config.TASK_CONFIG.TASK
        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))

        # Initialize auxiliary tasks
        observation_space = self.envs.observation_spaces[0]
        aux_cfg = self.config.RL.AUX_TASKS
        init_aux_tasks, num_recurrent_memories, aux_task_strings = \
            self._setup_auxiliary_tasks(aux_cfg, ppo_cfg, task_cfg, observation_space)

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            observation_space,
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
            num_recurrent_memories=num_recurrent_memories)
        rollouts.to(self.device)

        # Seed the rollout storage with the first observations.
        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        self._setup_actor_critic_agent(ppo_cfg, task_cfg, aux_cfg,
                                       init_aux_tasks)
        logger.info("agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())))

        # Per-env running reward plus windowed episode statistics for logging.
        current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1),
            reward=torch.zeros(self.envs.num_envs, 1),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size))

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        if ckpt != -1:
            # Resume path: restore agent (and optimizer, when present) from
            # the given checkpoint before entering the update loop.
            logger.info(
                f"Resuming runs at checkpoint {ckpt}. Timing statistics are not tracked properly."
            )
            assert ppo_cfg.use_linear_lr_decay is False and ppo_cfg.use_linear_clip_decay is False, "Resuming with decay not supported"
            # This is the checkpoint we start saving at
            count_checkpoints = ckpt + 1
            count_steps = start_updates * ppo_cfg.num_steps * self.config.NUM_PROCESSES
            ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu")
            self.agent.load_state_dict(ckpt_dict["state_dict"])
            if "optim_state" in ckpt_dict:
                self.agent.optimizer.load_state_dict(ckpt_dict["optim_state"])
            else:
                # NOTE(review): logger.warn is the deprecated alias of
                # logger.warning in the stdlib logging API — confirm this
                # project logger mirrors that.
                logger.warn("No optimizer state loaded, results may be funky")
            if "extra_state" in ckpt_dict and "step" in ckpt_dict[
                    "extra_state"]:
                count_steps = ckpt_dict["extra_state"]["step"]

        # Linear LR decay across NUM_UPDATES; only stepped when
        # use_linear_lr_decay is enabled below.
        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        with TensorboardWriter(self.config.TENSORBOARD_DIR,
                               flush_secs=self.flush_secs) as writer:

            for update in range(start_updates, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                # Collect one rollout of num_steps transitions per env.
                for step in range(ppo_cfg.num_steps):
                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(rollouts,
                                                   current_episode_reward,
                                                   running_episode_stats)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps += delta_steps

                delta_pth_time, value_loss, action_loss, dist_entropy, aux_task_losses, aux_dist_entropy, aux_weights = self._update_agent(
                    ppo_cfg, rollouts)
                pth_time += delta_pth_time

                for k, v in running_episode_stats.items():
                    window_episode_stats[k].append(v.clone())

                # Windowed deltas: difference between newest and oldest
                # snapshot of each running stat.
                deltas = {
                    k:
                    ((v[-1] -
                      v[0]).sum().item() if len(v) > 1 else v[0].sum().item())
                    for k, v in window_episode_stats.items()
                }
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar(
                    "entropy",
                    dist_entropy,
                    count_steps,
                )

                writer.add_scalar("aux_entropy", aux_dist_entropy, count_steps)

                writer.add_scalar("reward", deltas["reward"] / deltas["count"],
                                  count_steps)

                # Check to see if there are any metrics
                # that haven't been logged yet
                metrics = {
                    k: v / deltas["count"]
                    for k, v in deltas.items() if k not in {"reward", "count"}
                }
                if len(metrics) > 0:
                    writer.add_scalars("metrics", metrics, count_steps)

                losses = [value_loss, action_loss] + aux_task_losses
                writer.add_scalars(
                    "losses",
                    {
                        k: l
                        for l, k in zip(losses, ["value", "policy"] +
                                        aux_task_strings)
                    },
                    count_steps,
                )

                writer.add_scalars(
                    "aux_weights",
                    {k: l
                     for l, k in zip(aux_weights, aux_task_strings)},
                    count_steps,
                )

                writer.add_scalar(
                    "success",
                    deltas["success"] / deltas["count"],
                    count_steps,
                )

                # Log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info(
                        "update: {}\tvalue_loss: {}\t action_loss: {}\taux_task_loss: {} \t aux_entropy {}"
                        .format(update, value_loss, action_loss,
                                aux_task_losses, aux_dist_entropy))
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update, count_steps / (time.time() - t_start)))

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))

                    logger.info("Average window size: {}  {}".format(
                        len(window_episode_stats["count"]),
                        "  ".join("{}: {:.3f}".format(k, v / deltas["count"])
                                  for k, v in deltas.items() if k != "count"),
                    ))

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(
                        f"{self.checkpoint_prefix}.{count_checkpoints}.pth",
                        dict(step=count_steps))
                    count_checkpoints += 1

        self.envs.close()