Example #1
def main():
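    # Build the vocabulary and GloVe-backed embedding matrix from the Yelp
    # sentiment files, then either train on the train/dev splits or run
    # inference on the test split, depending on config.train.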
    train_file_paths = ["data/yelp/sentiment.train.0", "data/yelp/sentiment.train.1"]
    dev_file_paths = ["data/yelp/sentiment.dev.0", "data/yelp/sentiment.dev.1"]
    test_file_paths = ["data/yelp/sentiment.test.0", "data/yelp/sentiment.test.1"]

    word2idx, idx2word, embedding = build_vocab(train_file_paths,
                                                glove_path=config.glove_path)
    if config.train:
        # prepare data loader for training
        train_loader = get_loader(train_file_paths[1],
                                  train_file_paths[0],
                                  word2idx,
                                  debug=config.debug,
                                  batch_size=config.batch_size)
        # prepare data loader for evaluation
        dev_loader = get_loader(dev_file_paths[1],
                                dev_file_paths[0],
                                word2idx,
                                shuffle=False,
                                debug=config.debug,
                                batch_size=config.batch_size)
        data_loaders = [train_loader, dev_loader]
        trainer = Trainer(embedding, data_loaders)
        trainer.train()
    else:
        test_loader = get_loader(test_file_paths[1],
                                 test_file_paths[0],
                                 word2idx,
                                 debug=config.debug,
                                 shuffle=False,
                                 batch_size=16)
        data_loaders = [test_loader]
        trainer = Trainer(embedding, data_loaders)
        trainer.inference(config.model_path, config.output_dir, idx2word)
Example #2
    def __init__(self, model_path=None):
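        # Trainer setup: load the pickled embedding matrix and word2idx mapping,
        # build the train/dev loaders, create a timestamped model directory, and
        # set up the Seq2seq model with SGD and cross-entropy that ignores index 0.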
        # load dictionary and embedding file
        with open(config.embedding, "rb") as f:
            embedding = pickle.load(f)
            embedding = torch.Tensor(embedding).to(config.device)
        with open(config.word2idx_file, "rb") as f:
            word2idx = pickle.load(f)

        # train, dev loader
        print("load train data")
        self.train_loader = get_loader(config.train_src_file,
                                       config.train_trg_file,
                                       word2idx,
                                       use_tag=config.use_tag,
                                       batch_size=config.batch_size,
                                       debug=config.debug)
        self.dev_loader = get_loader(config.dev_src_file,
                                     config.dev_trg_file,
                                     word2idx,
                                     use_tag=config.use_tag,
                                     batch_size=128,
                                     debug=config.debug)

        train_dir = os.path.join("./save", "seq2seq")
        self.model_dir = os.path.join(train_dir, "train_%d" % int(time.strftime("%m%d%H%M%S")))
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

        self.model = Seq2seq(embedding, config.use_tag, model_path=model_path)
        params = list(self.model.encoder.parameters()) \
                 + list(self.model.decoder.parameters())

        self.lr = config.lr
        self.optim = optim.SGD(params, self.lr, momentum=0.8)
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)
Example #3
    def __init__(self, args):
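        # Load word, entity, and relation embeddings plus their index mappings,
        # build the train/dev loaders, and create a Seq2seq model that can be
        # warm-started from args.model_path before training with SGD.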
        # load dictionary and embedding file
        with open(config.embedding, "rb") as f0:
            embedding = pickle.load(f0)
            embedding = torch.tensor(embedding,
                                     dtype=torch.float).to(config.device)
        with open(config.entity_embedding, "rb") as f1:
            ent_embedding = pickle.load(f1)
            ent_embedding = torch.tensor(ent_embedding,
                                         dtype=torch.float).to(config.device)
        with open(config.relation_embedding, "rb") as f2:
            rel_embedding = pickle.load(f2)
            rel_embedding = torch.tensor(rel_embedding,
                                         dtype=torch.float).to(config.device)
        with open(config.word2idx_file, "rb") as f:
            word2idx = pickle.load(f)
        with open(config.ent2idx_file, "rb") as g:
            ent2idx = pickle.load(g)
        with open(config.rel2idx_file, "rb") as h:
            rel2idx = pickle.load(h)

        # train, dev loader
        print("load train data")
        self.train_loader = get_loader(config.train_src_file,
                                       config.train_trg_file,
                                       config.train_csfile,
                                       word2idx,
                                       use_tag=True,
                                       batch_size=config.batch_size,
                                       debug=config.debug)
        self.dev_loader = get_loader(config.dev_src_file,
                                     config.dev_trg_file,
                                     config.dev_csfile,
                                     word2idx,
                                     use_tag=True,
                                     batch_size=128,
                                     debug=config.debug)

        train_dir = "./save"
        self.model_dir = os.path.join(
            train_dir, "train_%d" % int(time.strftime("%m%d%H%M%S")))
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

        self.model = Seq2seq(embedding, ent_embedding, rel_embedding)
        # self.model = nn.DataParallel(self.model)
        self.model = self.model.to(config.device)

        if len(args.model_path) > 0:
            print("load check point from: {}".format(args.model_path))
            state_dict = torch.load(args.model_path, map_location="cpu")
            self.model.load_state_dict(state_dict)

        params = self.model.parameters()

        self.lr = config.lr
        self.optim = optim.SGD(params, self.lr, momentum=0.8)
        # self.optim = optim.Adam(params)
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)
Example #4
    def __init__(self, model_path, output_dir):
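        # Inference-time setup: load the vocabulary, build a batch-size-1 test
        # loader, restore the Seq2seq checkpoint from model_path, and prepare
        # the generated, golden, and dummy src files under output_dir.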
        with open(config.word2idx_file, "rb") as f:
            word2idx = pickle.load(f)

        self.output_dir = output_dir
        self.test_data = open(config.test_trg_file, "r").readlines()
        self.data_loader = get_loader(config.test_src_file,
                                      config.test_trg_file,
                                      word2idx,
                                      batch_size=1,
                                      use_tag=True,
                                      shuffle=False)

        self.tok2idx = word2idx
        self.idx2tok = {idx: tok for tok, idx in self.tok2idx.items()}
        self.model = Seq2seq()
        state_dict = torch.load(model_path)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.model = self.model.to(config.device)
        self.pred_dir = os.path.join(output_dir, "generated.txt")
        self.golden_dir = os.path.join(output_dir, "golden.txt")
        self.src_file = os.path.join(output_dir, "src.txt")

        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        # dummy file for evaluation
        with open(self.src_file, "w") as f:
            for i in range(len(self.data_loader)):
                f.write(str(i) + "\n")
Example #5
    def __init__(self, model_path, output_dir):
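        # Inference setup for the BERT-based variant: build the test loader,
        # load the tokenizer and config from the MTBERT directory, restore the
        # checkpoint on GPU or CPU, and prepare the output and dummy src files.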
        self.logger = logging.getLogger('paragraph-level')

        self.output_dir = output_dir
        self.test_data = open(config.test_trg_file, "r").readlines()
        self.data_loader = get_loader(config.test_src_file,
                                      config.test_trg_file,
                                      config.test_ans_file,
                                      batch_size=1,
                                      use_tag=False,
                                      shuffle=False)

        self.tokenizer = BertTokenizer.from_pretrained(r'MTBERT/vocab.txt')
        self.model_config = BertConfig.from_pretrained('MTBERT')
        self.model = Seq2seq()
        if config.use_gpu:
            state_dict = torch.load(model_path, map_location=config.device)
        else:
            state_dict = torch.load(model_path, map_location='cpu')

        self.model.load_state_dict(state_dict)
        self.model.eval()
        if config.use_gpu:
            self.model = self.model.to(config.device)
        self.pred_dir = 'result/pointer_maxout_ans/generated.txt'
        self.golden_dir = 'result/pointer_maxout_ans/golden.txt'
        self.src_file = 'result/pointer_maxout_ans/src.txt'

        # dummy file for evaluation
        with open(self.src_file, "w") as f:
            for i in range(len(self.data_loader)):
                f.write(str(i) + "\n")
Example #6
    def __init__(self, args):
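        # Trainer setup for the paragraph-level model: build train/dev loaders,
        # create a timestamped save directory, optionally restore a checkpoint,
        # freeze the BERT encoder, and optimize the remaining parameters with SGD.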
        self.logger = logging.getLogger('paragraph-level')

        # train, dev loader
        print("load train data")
        self.train_loader = get_loader(config.train_src_file,
                                       config.train_trg_file,
                                       config.train_ans_file,
                                       batch_size=config.batch_size,
                                       debug=config.debug,
                                       shuffle=True)
        self.dev_loader = get_loader(config.dev_src_file,
                                     config.dev_trg_file,
                                     config.dev_ans_file,
                                     batch_size=128,
                                     debug=config.debug)

        train_dir = os.path.join(config.file_path + "save", "seq2seq")
        self.model_dir = os.path.join(
            train_dir, "train_%d" % int(time.strftime("%m%d%H%M%S")))
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

        self.model = Seq2seq()
        if config.use_gpu:
            self.model = self.model.to(config.device)

        if len(args.model_path) > 0:
            print("load check point from: {}".format(args.model_path))
            state_dict = torch.load(args.model_path, map_location="cpu")
            self.model.load_state_dict(state_dict)

        params = self.model.parameters()
        bert_params = self.model.bert_encoder.named_parameters()
        for name, param in bert_params:
            param.requires_grad = False
        base_params = filter(lambda p: p.requires_grad,
                             self.model.parameters())
        self.lr = config.lr
        self.optim = optim.SGD(base_params, self.lr, momentum=0.8)
        # self.optim = optim.Adam(params)
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)
Example #7
    def __init__(self, model_path, output_dir):
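        # Inference setup: load the vocabulary, build a batch-size-1 test loader,
        # construct a Seq2seq model directly from model_path, and set the paths
        # for the generated and golden output files under output_dir.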
        with open(config.word2idx_file, "rb") as f:
            word2idx = pickle.load(f)

        self.output_dir = output_dir
        self.test_data = open(config.test_trg_file, "r").readlines()
        self.data_loader = get_loader(config.test_src_file,
                                      config.test_trg_file,
                                      word2idx,
                                      batch_size=1,
                                      use_tag=config.use_tag,
                                      shuffle=False)

        self.tok2idx = word2idx
        self.idx2tok = {idx: tok for tok, idx in self.tok2idx.items()}
        self.model = Seq2seq(model_path=model_path)
        self.pred_dir = output_dir + "/generated.txt"
        self.golden_dir = output_dir + "/golden.txt"
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
Example #8
def main():
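    # Image-captioning training: build the data loader, create a ResNet-based
    # CaptionModel on the GPU, and run the training loop with Adam.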
    dataloader, vocab_size, n_class = get_loader(batch_size=2)

    cap_model = CaptionModel(
        embed_dim=512,
        model_name='resnet',
        total_vocab=vocab_size,
        n_class=n_class,
        hidden_size=512,
        num_layers=2)

    cap_model = cap_model.cuda()

    optimizer = optim.Adam(cap_model.get_train_param())

    train(
        epochs=100,
        save_point=10,
        model=cap_model,
        dataloader=dataloader,
        criterion=get_performance,
        optimizer=optimizer,
        print_step=100)
Example #9
def main_worker(gpu, config):
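    # Per-GPU worker for (optionally distributed) GAN training: set up the process
    # group, optional cluster-centroid embeddings, the encoder/generator/discriminator
    # with their Adam optimizers, logging, the data loaders, and the Trainer.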
    # GPU is assigned
    config.gpu = gpu
    config.rank = gpu
    print(f'Launching at GPU {gpu}')

    if config.distributed:
        dist.init_process_group(
            backend='nccl',
            init_method='tcp://127.0.0.1:9001',
            # init_method="env://",
            world_size=config.world_size,
            rank=config.rank)

    if config.clustering:
        feat_dim = config.emb_dim  # feat_dim
        imsize = config.resize_input_size  # imsize
        n_centroids = config.n_centroids
        n_iter = config.n_iter
        encoder = config.encoder
        cluster_src = config.cluster_src

        centroid_dir = Path('../datasets/cluster_centroids/').resolve()
        if config.im_ratio == 'original':
            centroid_path = centroid_dir.joinpath(
                f'{encoder}_{cluster_src}_centroids{n_centroids}_iter{n_iter}_d{feat_dim}_grid{config.n_grid}.npy'
            )
        else:
            centroid_path = centroid_dir.joinpath(
                f'{encoder}_{cluster_src}_centroids{n_centroids}_iter{n_iter}_d{feat_dim}_grid{config.n_grid}_imsize{imsize}.npy'
            )
        centroids = np.load(centroid_path)

        Emb = nn.Embedding.from_pretrained(torch.from_numpy(centroids),
                                           freeze=True)
    else:
        Emb = None

    if config.classifier is None:
        E = None
    elif config.classifier == 'resnet101':
        E = ResNetEncoder('resnet101')
    elif config.classifier == 'resnet50':
        E = ResNetEncoder('resnet50')

    G = Generator(
        base_dim=config.g_base_dim,
        emb_dim=config.emb_dim,
        mod_dim=config.y_mod_dim,
        n_channel=config.n_channel,
        target_size=config.resize_target_size,
        extra_layers=config.g_extra_layers,
        init_H=config.n_grid,
        init_W=config.n_grid,
        norm_type=config.g_norm_type,
        SN=config.SN,
        codebook_dim=config.codebook_dim,
    )

    if config.gan:
        D = Discriminator(base_dim=config.d_base_dim,
                          emb_dim=config.emb_dim,
                          n_channel=config.n_channel,
                          target_size=config.resize_target_size,
                          extra_layers=config.d_extra_layers,
                          init_H=config.n_grid,
                          init_W=config.n_grid,
                          SN=config.SN,
                          ACGAN=config.ACGAN,
                          n_classes=config.n_centroids)
        if config.ACGAN:
            D.emb_classifier.weight = Emb.weight
    else:
        D = None

    # Logging
    if config.gpu == 0:
        logger = logging.getLogger('mylogger')
        file_handler = logging.FileHandler(config.log_dir.joinpath('log.txt'))
        stream_handler = logging.StreamHandler()
        logger.addHandler(file_handler)
        # logger.addHandler(stream_handler)
        logger.setLevel(logging.DEBUG)

        print('#===== (Trainable) Parameters =====#')

        def count_parameters(model):
            return sum(p.numel() for p in model.parameters()
                       if p.requires_grad)

        n_params = 0
        for model_name, model in [('E', E), ('G', G), ('D', D), ('Emb', Emb)]:
            if model is not None:
                # print(model)
                logger.info(model)
                # for name, p in model.named_parameters():
                #     print(name, '\t', list(p.size()))
                n_param = count_parameters(model)
                log_str = f'# {model_name} Parameters: {n_param}'
                print(log_str)
                logger.info(log_str)
                n_params += n_param
        log_str = f'# Total Parameters: {n_params}'
        logger.info(log_str)
        print(log_str)

        config.save(config.log_dir.joinpath('config.yaml'))

        # Save scripts for backup
        log_src_dir = config.log_dir.joinpath(f'src/')
        log_src_dir.mkdir(exist_ok=True)
        proj_dir = Path(__file__).resolve().parent
        for path in proj_dir.glob('*.py'):
            tgt_path = log_src_dir.joinpath(path.name)
            shutil.copy(path, tgt_path)
    else:
        logger = None

    if config.distributed:
        torch.cuda.set_device(config.gpu)

    if config.distributed:
        if 'bn' in config.g_norm_type:
            G = nn.SyncBatchNorm.convert_sync_batchnorm(G)
        G = G.cuda(config.gpu)

        params = G.parameters()

        g_optim = optim.Adam(
            params,
            lr=config.g_lr,
            betas=[config.g_adam_beta1, config.g_adam_beta2],
            eps=config.adam_eps,
        )

        if config.mixed_precision:
            G, g_optim = amp.initialize(G, g_optim, opt_level='O1')

        G = DDP(G,
                device_ids=[config.gpu],
                find_unused_parameters=True,
                broadcast_buffers=not config.SN)
    else:
        G = G.cuda()

        params = G.parameters()

        g_optim = optim.Adam(
            params,
            lr=config.g_lr,
            betas=[config.g_adam_beta1, config.g_adam_beta2],
            eps=config.adam_eps,
        )
        if config.multiGPU:
            G = nn.DataParallel(G)

    e_optim = None
    if config.classifier:
        if config.distributed:
            E = E.cuda(config.gpu)
        else:
            E = E.cuda()

        E = E.eval()
        if not config.distributed and config.multiGPU:
            E = nn.DataParallel(E)
    else:
        e_optim = None

    if config.gan:
        if config.distributed:
            D = D.cuda(config.gpu)

            d_optim = optim.Adam(
                D.parameters(),
                lr=config.d_lr,
                betas=[config.d_adam_beta1, config.d_adam_beta2],
                eps=config.adam_eps,
            )

            if config.mixed_precision:
                D, d_optim = amp.initialize(D, d_optim, opt_level='O1')

            D = DDP(D,
                    device_ids=[config.gpu],
                    find_unused_parameters=True,
                    broadcast_buffers=not config.SN)
        else:
            D = D.cuda()

            d_optim = optim.Adam(
                D.parameters(),
                lr=config.d_lr,
                betas=[config.d_adam_beta1, config.d_adam_beta2],
                eps=config.adam_eps,
            )
            if config.multiGPU:
                D = nn.DataParallel(D)
    else:
        d_optim = None
    if config.clustering:
        if config.distributed:
            Emb = Emb.cuda(config.gpu)

        else:
            Emb = Emb.cuda()
            if config.multiGPU:
                Emb = nn.DataParallel(Emb)

    train_transform = transforms.Compose([
        transforms.Resize((config.resize_input_size, config.resize_input_size),
                          interpolation=Image.LANCZOS),
    ])
    valid_transform = transforms.Compose([
        transforms.Resize((config.resize_input_size, config.resize_input_size),
                          interpolation=Image.LANCZOS),
    ])

    data_out = ['img']
    if config.clustering:
        data_out.append('cluster_id')

    train_set = 'mscoco_train'
    if config.run_minival:
        train_set = 'mscoco_minival'

    train_loader = get_loader(config,
                              train_set,
                              mode='train',
                              batch_size=config.batch_size,
                              distributed=config.distributed,
                              gpu=config.gpu,
                              workers=config.workers,
                              transform=train_transform,
                              topk=config.train_topk,
                              data_out=data_out)

    if config.distributed:
        valid_batch_size = config.batch_size
    else:
        valid_batch_size = config.batch_size // 4

    val_loader = get_loader(config,
                            'mscoco_minival',
                            mode='val',
                            batch_size=valid_batch_size,
                            distributed=config.distributed,
                            gpu=config.gpu,
                            workers=0,
                            transform=valid_transform,
                            topk=config.valid_topk,
                            data_out=data_out)

    trainer = Trainer(config, E, G, D, Emb, g_optim, d_optim, e_optim,
                      train_loader, val_loader, logger)
    trainer.train()
Example #10
def main():
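    # Semi-supervised training on STL-10: build labeled/unlabeled/validation
    # loaders, create a WRN-28-2 model plus an EMA copy, optionally resume from
    # a checkpoint, then train and validate for args.epochs epochs.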
    global best_acc

    if not os.path.isdir(args.out):
        mkdir_p(args.out)

    # Data
    print(f'==> Preparing stl-10')
    # transform_train = transforms.Compose([
    #     dataset.RandomPadandCrop(32),
    #     dataset.RandomFlip(),
    #     dataset.ToTensor(),
    # ])
    # transform_val = transforms.Compose([
    #     dataset.ToTensor(),
    # ])

    size = int(96 * 0.9)

    transform_train = transforms.Compose([
        # dataset.RandomPadandCrop(32),
        # transforms.Resize((32, 32)),
        # transforms.RandomCrop(32),
        transforms.RandomCrop(size,
                              padding=None,
                              pad_if_needed=False,
                              fill=0,
                              padding_mode='constant'),
        transforms.RandomHorizontalFlip(p=0.5),
        # transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
    ])

    # train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10('./data', args.n_labeled, transform_train=transform_train, transform_val=transform_val)
    # labeled_trainloader = data.DataLoader(train_labeled_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
    # unlabeled_trainloader = data.DataLoader(train_unlabeled_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
    # val_loader = data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=0)
    # test_loader = data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=0)

    train_unlabeled_set = stl10.get_dataset_all(transform=transform_train)
    unlabeled_trainloader = stl10.get_loader(train_unlabeled_set,
                                             batch_size=args.batch_size,
                                             num_workers=0)
    train_labeled_set = stl10.get_train_label_dataset(
        transform=transform_train)
    labeled_trainloader = stl10.get_loader(train_labeled_set,
                                           batch_size=args.batch_size,
                                           num_workers=0)
    test_dataset = stl10.get_test_dataset()
    val_loader = stl10.get_loader(test_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=0)

    # Model
    print("==> creating WRN-28-2")

    def create_model(ema=False):
        model = models.WideResNet(num_classes=10)
        model = model.cuda()

        if ema:
            for param in model.parameters():
                param.detach_()

        return model

    model = create_model()
    ema_model = create_model(ema=True)

    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    train_criterion = SemiLoss()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    ema_optimizer = WeightEMA(model, ema_model, alpha=args.ema_decay)
    start_epoch = 0

    # Resume
    title = 'noisy-cifar-10'
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint directory found!'
        args.out = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        ema_model.load_state_dict(checkpoint['ema_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.out, 'log.txt'), title=title)
    else:
        logger = Logger(os.path.join(args.out, 'log.txt'), title=title)
        logger.set_names([
            'Train Loss', 'Train Loss X', 'Train Loss U', 'Valid Loss',
            'Valid Acc.'
        ])

    writer = SummaryWriter(args.out)
    step = 0
    # test_accs = []
    # Train and val

    if args.train == 0:
        val_loss, val_acc = validate(val_loader,
                                     ema_model,
                                     criterion,
                                     0,
                                     use_cuda,
                                     mode='Valid Stats')
        return

    for epoch in range(start_epoch, args.epochs):

        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.epochs, state['lr']))

        train_loss, train_loss_x, train_loss_u = train(
            labeled_trainloader, unlabeled_trainloader, model, optimizer,
            ema_optimizer, train_criterion, epoch, use_cuda)
        _, train_acc = validate(labeled_trainloader,
                                ema_model,
                                criterion,
                                epoch,
                                use_cuda,
                                mode='Train Stats')
        val_loss, val_acc = validate(val_loader,
                                     ema_model,
                                     criterion,
                                     epoch,
                                     use_cuda,
                                     mode='Valid Stats')
        # test_loss, test_acc = validate(test_loader, ema_model, criterion, epoch, use_cuda, mode='Test Stats ')

        step = args.val_iteration * (epoch + 1)

        writer.add_scalar('losses/train_loss', train_loss, step)
        writer.add_scalar('losses/valid_loss', val_loss, step)
        # writer.add_scalar('losses/test_loss', test_loss, step)

        writer.add_scalar('accuracy/train_acc', train_acc, step)
        writer.add_scalar('accuracy/val_acc', val_acc, step)
        # writer.add_scalar('accuracy/test_acc', test_acc, step)

        # append logger file
        # logger.append([train_loss, train_loss_x, train_loss_u, val_loss, val_acc, test_loss, test_acc])
        logger.append(
            [train_loss, train_loss_x, train_loss_u, val_loss, val_acc])

        # save model
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'ema_state_dict': ema_model.state_dict(),
                'acc': val_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best)
        # test_accs.append(test_acc)
    logger.close()
    writer.close()

    print('Best acc:')
    print(best_acc)
Example #11
def main(args):
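    # Grow a simplex of ViT-B/16 vertices on CIFAR-100: load a base checkpoint,
    # repeatedly add a vertex, train it with SGD and cosine annealing, and save
    # a checkpoint after each vertex.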
    savedir = "./saved-outputs/model" + str(args.base_idx) + "/"
    print('Preparing directory %s' % savedir)
    os.makedirs(savedir, exist_ok=True)
    with open(os.path.join(savedir, 'base_command.sh'), 'w') as f:
        f.write(' '.join(sys.argv))
        f.write('\n')
    
    trainloader, testloader = get_loader(args)
    
    config = CONFIGS['ViT-B_16']
    num_classes = 100
    model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes)
    modeldir = "./cifar100-100_500_seed_" + str(args.base_idx) + "/"  
    modelname = "cifar100-100_500_seed_" + str(args.base_idx) + "_checkpoint.bin"
    model.load_state_dict(torch.load(modeldir+modelname))
    
    simplex_model = BasicSimplex(model, num_vertices=1, fixed_points=[False]).cuda()
    del model

    ## add a new points and train ##
    for vv in range(1, args.n_verts+1):
        simplex_model.add_vert()
        simplex_model = simplex_model.cuda()
        optimizer = torch.optim.SGD(
            simplex_model.parameters(),
            lr=args.lr_init,
            momentum=0.9,
            weight_decay=args.wd
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 
                                                               T_max=args.epochs)
        criterion = torch.nn.CrossEntropyLoss()
        columns = ['vert', 'ep', 'lr', 'tr_loss', 
                   'tr_acc', 'te_loss', 'te_acc', 'time']
        for epoch in range(args.epochs):
            time_ep = time.time()
            train_res = simp_utils.train_transformer_epoch(
                trainloader, 
                simplex_model, 
                criterion,
                optimizer,
                args.n_sample,
                vol_reg=1e-4,
                gradient_accumulation_steps=args.gradient_accumulation_steps,
            )

            start_ep = (epoch == 0)
            eval_ep = epoch % args.eval_freq == args.eval_freq - 1
            end_ep = epoch == args.epochs - 1
            # test_res = {'loss': None, 'accuracy': None}
            if eval_ep:
                test_res = simp_utils.eval(testloader, simplex_model, criterion)
            else:
                test_res = {'loss': None, 'accuracy': None}

            time_ep = time.time() - time_ep

            lr = optimizer.param_groups[0]['lr']
            scheduler.step()

            values = [vv, epoch + 1, lr, 
                      train_res['loss'], train_res['accuracy'], 
                      test_res['loss'], test_res['accuracy'], time_ep]

            table = tabulate.tabulate([values], columns, 
                                      tablefmt='simple', floatfmt='8.4f')
            if epoch % 40 == 0:
                table = table.split('\n')
                table = '\n'.join([table[1]] + table)
            else:
                table = table.split('\n')[2]
            print(table, flush=True)

        checkpoint = simplex_model.state_dict()
        fname = "lr_"+str(args.lr_init)+"simplex_vertex" + str(vv) + ".pt"
        torch.save(checkpoint, savedir + fname) 
Example #12
def train(local_rank, args, hp, model):
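    # (Optionally distributed) classification training loop: build the loaders,
    # set up AdamW with a warmup schedule, train with gradient accumulation, and
    # periodically validate, saving the best-accuracy model.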

    if hp.train.ngpu > 1:
        dist.init_process_group(backend="nccl",
                                init_method="tcp://localhost:54321",
                                world_size=hp.train.ngpu,
                                rank=local_rank)

    torch.cuda.manual_seed(hp.train.seed)
    device = torch.device('cuda:{:d}'.format(local_rank))
    model = model.to(device)
    """ Train the model """
    if local_rank in [-1, 0]:
        os.makedirs(hp.data.outdir, exist_ok=True)
        writer = SummaryWriter(log_dir=os.path.join("logs", args.name))
        print("Loading dataset :")

    hp.train.batch = hp.train.batch // hp.train.accum_grad

    # Prepare dataset
    train_loader, test_loader = get_loader(local_rank, hp)

    # Prepare optimizer and scheduler

    optimizer = torch.optim.AdamW(model.parameters(),
                                  hp.train.lr,
                                  betas=[0.8, 0.99],
                                  weight_decay=0.05)
    t_total = hp.train.num_steps
    if hp.train.decay_type == "cosine":
        scheduler = WarmupCosineSchedule(optimizer,
                                         warmup_steps=hp.train.warmup_steps,
                                         t_total=t_total)
    else:
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=hp.train.warmup_steps,
                                         t_total=t_total)

    # Distributed training
    if hp.train.ngpu > 1:
        model = DDP(model, device_ids=[local_rank])

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Total optimization steps = %d", hp.train.num_steps)
    logger.info("  Instantaneous batch size per GPU = %d", hp.train.batch)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        hp.train.batch * hp.train.accum_grad *
        (hp.train.ngpu if local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", hp.train.accum_grad)

    model.zero_grad()
    set_seed(
        hp)  # Added here for reproducibility (even between python 2 and 3)
    losses = AverageMeter()
    global_step, best_acc = 0, 0

    loss_fct = torch.nn.CrossEntropyLoss()
    while True:
        model.train()
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              bar_format="{l_bar}{r_bar}",
                              dynamic_ncols=True,
                              disable=local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(device) for t in batch)
            x, y = batch
            logits = model(x)

            loss = loss_fct(logits.view(-1, hp.model.num_classes), y.view(-1))

            if hp.train.accum_grad > 1:
                loss = loss / hp.train.accum_grad
            loss.backward()

            if (step + 1) % hp.train.accum_grad == 0:
                losses.update(loss.item() * hp.train.accum_grad)
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               hp.train.grad_clip)

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                epoch_iterator.set_description(
                    "Training (%d / %d Steps) (loss=%2.5f)" %
                    (global_step, t_total, losses.val))
                if local_rank in [-1, 0]:
                    writer.add_scalar("train/loss",
                                      scalar_value=losses.val,
                                      global_step=global_step)
                    writer.add_scalar("train/lr",
                                      scalar_value=scheduler.get_lr()[0],
                                      global_step=global_step)
                if global_step % hp.train.valid_step == 0 and local_rank in [
                        -1, 0
                ]:
                    accuracy = valid(device, local_rank, hp, model, writer,
                                     test_loader, global_step)
                    if best_acc < accuracy:
                        save_model(args.name, hp.data.outdir, model)
                        best_acc = accuracy
                    model.train()

                if global_step % t_total == 0:
                    break
        losses.reset()
        if global_step % t_total == 0:
            break

    if local_rank in [-1, 0]:
        writer.close()
    logger.info("Best Accuracy: \t%f" % best_acc)
    logger.info("End Training!")