def main():
    """Train a SCAN model wrapped in ``nn.DataParallel``; checkpoint each epoch.

    Options come from ``opts.parse_opt()``.  ``opt.gpuid`` is read as a
    comma-separated list of GPU ids: it is exported via
    ``CUDA_VISIBLE_DEVICES`` and the number of entries is used as the device
    count.  If ``opt.resume`` names an existing file, the model weights, the
    starting epoch and the best validation score are restored from it.  After
    every epoch the model is evaluated with ``evalrank`` and a checkpoint is
    written under ``opt.model_name``.
    """
    opt = opts.parse_opt()

    visible_gpus = str(opt.gpuid)
    device_count = len(visible_gpus.split(","))
    print("use GPU:", opt.gpuid, "GPUs_count", device_count, flush=True)
    os.environ['CUDA_VISIBLE_DEVICES'] = visible_gpus
    # Device indices are relative to the CUDA_VISIBLE_DEVICES mask set above.
    torch.cuda.set_device(0)

    # Load Vocabulary Wrapper
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data.get_loaders(
        opt.data_name, vocab, opt.batch_size, opt.workers, opt)

    # Construct the model
    model = SCAN(opt)
    model.cuda()
    model = nn.DataParallel(model)

    # Loss and optimizer
    criterion = ContrastiveLoss(opt=opt, margin=opt.margin,
                                max_violation=opt.max_violation)
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(opt.model_name, exist_ok=True)
    start_epoch = 0
    best_rsum = 0

    # Optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            # Checkpoints written below store 'epoch' as the NEXT epoch to run.
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
                  .format(opt.resume, start_epoch, best_rsum))
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))
    # Evaluate once before training starts (runs whether or not we resumed).
    evalrank(model.module, val_loader, opt)

    print(opt, flush=True)

    log_file = os.path.join(opt.logger_name, "performance.log")
    log_interval = 100

    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        message = "epoch: %d, model name: %s\n" % (epoch, opt.model_name)
        logging_func(log_file, message)
        print("model name: ", opt.model_name, flush=True)
        adjust_learning_rate(opt, optimizer, epoch)

        run_time = 0
        timed_batches = 0
        for i, (images, captions, lengths, masks, ids, _) in enumerate(train_loader):
            start_time = time.time()
            # Re-enter train mode every step: evalrank() is called inside this
            # loop and may have switched the model to eval mode.
            model.train()

            optimizer.zero_grad()

            if device_count != 1:
                # NOTE(review): presumably tiles the images so every
                # DataParallel replica gets the whole image batch after the
                # scatter along dim 0 -- confirm against SCAN.forward.
                images = images.repeat(device_count, 1, 1)

            score = model(images, captions, lengths, masks, ids)
            loss = criterion(score)

            loss.backward()
            if opt.grad_clip > 0:
                clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()

            run_time += time.time() - start_time
            timed_batches += 1

            if i % log_interval == 0:
                # Average over the batches actually timed since the last log;
                # the first report (i == 0) covers a single batch, not 100.
                log = "epoch: %d; batch: %d/%d; loss: %.4f; time: %.4f" % (
                    epoch, i, len(train_loader), loss.item(),
                    run_time / timed_batches)
                print(log, flush=True)
                run_time = 0
                timed_batches = 0

            # validate at every val_step
            if (i + 1) % opt.val_step == 0:
                evalrank(model.module, val_loader, opt)

        print("-------- performance at epoch: %d --------" % (epoch))
        # evaluate on validation set
        rsum = evalrank(model.module, val_loader, opt)
        filename = 'model_' + str(epoch) + '.pth.tar'
        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        save_checkpoint({
            'epoch': epoch + 1,
            'model': model.state_dict(),
            'best_rsum': best_rsum,
            'opt': opt,
        }, is_best, filename=filename, prefix=opt.model_name + '/')
예제 #2
0
파일: trainer.py 프로젝트: Willamjie/SCAN-1
    def train(self):
        """Train a SCAN model with a margin loss and hard-negative mining.

        Builds the model from ``self.params``, loads initial weights from
        ``models/model_weights_5.t7``, then trains for
        ``self.params.num_epochs`` epochs.  Weights are saved after every
        epoch; every ``self.params.validate_every`` epochs recall is computed
        with ``self.evaluator.recall`` and the weights are saved again when
        R@1 improves.  A ``KeyboardInterrupt`` saves the current weights
        before returning.
        """
        model = SCAN(self.params)
        model.apply(init_xavier)
        # NOTE(review): the warm-start checkpoint path is hard-coded.
        model.load_state_dict(torch.load('models/model_weights_5.t7'))
        loss_function = MarginLoss(self.params.margin)
        if torch.cuda.is_available():
            model = model.cuda()
            loss_function = loss_function.cuda()
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=self.params.lr,
                                     weight_decay=self.params.wdecay)
        try:
            prev_best = 0
            for epoch in range(self.params.num_epochs):
                losses = []
                start_time = timer()
                for (caption, mask, image, neg_cap, neg_mask,
                     neg_image) in tqdm(self.data_loader.training_data_loader):

                    # Sample according to hard negative mining
                    (caption, mask, image, neg_cap, neg_mask,
                     neg_image) = self.data_loader.hard_negative_mining(
                         model, caption, mask, image, neg_cap, neg_mask,
                         neg_image)
                    model.train()
                    optimizer.zero_grad()

                    # Forward passes: the matching pair, then the pair with a
                    # negative caption, then the pair with a negative image.
                    similarity = model(to_variable(caption), to_variable(mask),
                                       to_variable(image), False)
                    similarity_neg_1 = model(to_variable(neg_cap),
                                             to_variable(neg_mask),
                                             to_variable(image), False)
                    similarity_neg_2 = model(to_variable(caption),
                                             to_variable(mask),
                                             to_variable(neg_image), False)

                    # Compute the loss and gradients, then update parameters.
                    loss = loss_function(similarity, similarity_neg_1,
                                         similarity_neg_2)
                    loss.backward()
                    losses.append(loss.item())
                    if self.params.clip_value > 0:
                        # clip_grad_norm_ is the in-place successor of the
                        # deprecated clip_grad_norm.
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       self.params.clip_value)
                    optimizer.step()

                # Parenthesised on purpose: ``epoch + 1 % step`` parses as
                # ``epoch + (1 % step)``, which made this branch (almost)
                # never run.
                if (epoch + 1) % self.params.step_size == 0:
                    # The learning rate is divided by gamma, as before.
                    optimizer.param_groups[0]['lr'] = (
                        optimizer.param_groups[0]['lr'] / self.params.gamma)

                torch.save(
                    model.state_dict(), self.params.model_dir +
                    '/model_weights_{}.t7'.format(epoch + 1))

                # np.asscalar no longer exists in NumPy; float() is equivalent.
                mean_loss = float(np.mean(losses))

                # Calculate r@k after each epoch
                if (epoch + 1) % self.params.validate_every == 0:
                    r_at_1, r_at_5, r_at_10 = self.evaluator.recall(
                        model, is_test=False)

                    print(
                        "Epoch {} : Training Loss: {:.5f}, R@1 : {}, R@5 : {}, R@10 : {}, Time elapsed {:.2f} mins"
                        .format(epoch + 1, mean_loss,
                                r_at_1, r_at_5, r_at_10,
                                (timer() - start_time) / 60))
                    if r_at_1 > prev_best:
                        print("Recall at 1 increased....saving weights !!")
                        prev_best = r_at_1
                        # The '/' was missing here, so the file landed next
                        # to model_dir instead of inside it.
                        torch.save(
                            model.state_dict(), self.params.model_dir +
                            '/best_model_weights_{}.t7'.format(epoch + 1))
                else:
                    print("Epoch {} : Training Loss: {:.5f}".format(
                        epoch + 1, mean_loss))
        except KeyboardInterrupt:
            print("Interrupted.. saving model !!!")
            torch.save(model.state_dict(),
                       self.params.model_dir + '/model_weights_interrupt.t7')
예제 #3
0
파일: train.py 프로젝트: WangBenHui/CRGN
def main():
    """Parse command-line options, build a SCAN model and train it.

    The vocabulary is unpickled from ``<vocab_path>/<data_name>_vocab.pkl``.
    If ``--resume`` names an existing file, model weights and ``Eiters`` are
    loaded from it and the model is validated once.  Each epoch then runs
    ``train`` and ``validate``, and a checkpoint is written under
    ``--model_name``.
    """
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_path',
        default='/data3/zhangyf/cross_modal_retrieval/SCAN/data',
        help='path to datasets')
    parser.add_argument('--data_name', default='f30k_precomp',
                        help='{coco,f30k}_precomp')
    parser.add_argument(
        '--vocab_path',
        default='/data3/zhangyf/cross_modal_retrieval/SCAN/vocab/',
        help='Path to saved vocabulary json files.')
    parser.add_argument('--margin', default=0.2, type=float,
                        help='Rank loss margin.')
    parser.add_argument('--num_epochs', default=20, type=int,
                        help='Number of training epochs.')
    parser.add_argument('--batch_size', default=128, type=int,
                        help='Size of a training mini-batch.')
    parser.add_argument('--word_dim', default=300, type=int,
                        help='Dimensionality of the word embedding.')
    # NOTE(review): this help text duplicates --word_dim's; it looks like a
    # copy-paste slip, but the intended description is not visible here.
    parser.add_argument('--decoder_dim', default=512, type=int,
                        help='Dimensionality of the word embedding.')
    parser.add_argument('--embed_size', default=1024, type=int,
                        help='Dimensionality of the joint embedding.')
    parser.add_argument('--grad_clip', default=2., type=float,
                        help='Gradient clipping threshold.')
    parser.add_argument('--num_layers', default=1, type=int,
                        help='Number of GRU layers.')
    parser.add_argument('--learning_rate', default=.0002, type=float,
                        help='Initial learning rate.')
    parser.add_argument('--lr_update', default=10, type=int,
                        help='Number of epochs to update the learning rate.')
    parser.add_argument('--workers', default=4, type=int,
                        help='Number of data loader workers.')
    parser.add_argument('--log_step', default=30, type=int,
                        help='Number of steps to print and record the log.')
    parser.add_argument('--val_step', default=500, type=int,
                        help='Number of steps to run validation.')
    parser.add_argument('--logger_name', default='./runs/runX/log',
                        help='Path to save Tensorboard log.')
    parser.add_argument('--model_name', default='./runs/runX/checkpoint',
                        help='Path to save the model.')
    parser.add_argument(
        '--resume',
        default='/data3/zhangyf/cross_modal_retrieval/vsepp_next_train_12_31_f30k/run/coco_vse++_ft_128_f30k_next/model_best.pth.tar',
        type=str,
        metavar='PATH',
        help='path to latest checkpoint (default: none)')
    parser.add_argument('--max_violation', action='store_true',
                        help='Use max instead of sum in the rank loss.')
    parser.add_argument('--img_dim', default=2048, type=int,
                        help='Dimensionality of the image embedding.')
    parser.add_argument('--no_imgnorm', action='store_true',
                        help='Do not normalize the image embeddings.')
    parser.add_argument('--no_txtnorm', action='store_true',
                        help='Do not normalize the text embeddings.')
    parser.add_argument('--precomp_enc_type', default="basic",
                        help='basic|weight_norm')
    parser.add_argument('--reset_train',
                        action='store_true',
                        help='Ensure the training is always done in '
                        'train mode (Not recommended).')
    parser.add_argument('--finetune', action='store_true',
                        help='Fine-tune the image encoder.')
    parser.add_argument('--cnn_type',
                        default='resnet152',
                        help="""The CNN used for image encoder
                        (e.g. vgg19, resnet152)""")
    parser.add_argument('--crop_size', default=224, type=int,
                        help='Size of an image crop as the CNN input.')

    opt = parser.parse_args()
    print(opt)

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    tb_logger.configure(opt.logger_name, flush_secs=5)

    # Load Vocabulary Wrapper
    vocab_file = os.path.join(opt.vocab_path, '%s_vocab.pkl' % opt.data_name)
    # SECURITY: unpickling can execute arbitrary code -- only point
    # --vocab_path at vocabulary files you trust.
    with open(vocab_file, 'rb') as vocab_handle:
        vocab = pickle.load(vocab_handle)
    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data.get_loaders(opt.data_name, vocab,
                                                opt.batch_size, opt.workers,
                                                opt)

    # Construct the model
    model = SCAN(opt)

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model
    # NOTE(review): the checkpoint's epoch/best_rsum are only reported above;
    # training always restarts at epoch 0 with best_rsum 0.  That is kept
    # because the default --resume path looks like pretrained weights rather
    # than an interrupted run of this script -- confirm before changing.
    best_rsum = 0
    for epoch in range(opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)

        adjust_learning_rate(opt, model.optimizer, epoch)

        # train for one epoch.  The result used to be stored in a misspelled
        # variable ("bset_rsum") and thrown away.
        # NOTE(review): assumes train() returns the best rsum it observed;
        # a None return (no return statement) is tolerated -- confirm.
        trained_best = train(opt, train_loader, model, epoch, val_loader,
                             best_rsum)
        if trained_best is not None:
            best_rsum = max(best_rsum, trained_best)

        # evaluate on validation set
        rsum = validate(opt, val_loader, model)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        # makedirs also creates missing parents of the nested default path.
        os.makedirs(opt.model_name, exist_ok=True)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'best_rsum': best_rsum,
                'opt': opt,
                'Eiters': model.Eiters,
            },
            is_best,
            filename='checkpoint_{}.pth.tar'.format(epoch),
            prefix=opt.model_name + '/')
예제 #4
0
파일: train.py 프로젝트: sungjune-p/SCAN
def main():
    """Parse command-line options, build a SCAN model and train it.

    If ``--resume`` names an existing checkpoint, the model weights,
    ``Eiters``, the best validation score and the epoch to continue from are
    restored from it, and the model is validated once.  Each epoch runs
    ``train`` and ``validate`` and writes a checkpoint under
    ``--model_name``.
    """
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', default='./data/',
                        help='path to datasets')
    parser.add_argument('--data_name', default='precomp',
                        help='{coco,f30k}_precomp')
    parser.add_argument('--vocab_path', default='./vocab/',
                        help='Path to saved vocabulary json files.')
    parser.add_argument('--margin', default=0.2, type=float,
                        help='Rank loss margin.')
    parser.add_argument('--num_epochs', default=30, type=int,
                        help='Number of training epochs.')
    parser.add_argument('--batch_size', default=128, type=int,
                        help='Size of a training mini-batch.')
    parser.add_argument('--word_dim', default=300, type=int,
                        help='Dimensionality of the word embedding.')
    parser.add_argument('--embed_size', default=1024, type=int,
                        help='Dimensionality of the joint embedding.')
    parser.add_argument('--grad_clip', default=2., type=float,
                        help='Gradient clipping threshold.')
    parser.add_argument('--num_layers', default=1, type=int,
                        help='Number of GRU layers.')
    parser.add_argument('--learning_rate', default=.0002, type=float,
                        help='Initial learning rate.')
    parser.add_argument('--lr_update', default=15, type=int,
                        help='Number of epochs to update the learning rate.')
    parser.add_argument('--workers', default=10, type=int,
                        help='Number of data loader workers.')
    parser.add_argument('--log_step', default=10, type=int,
                        help='Number of steps to print and record the log.')
    parser.add_argument('--val_step', default=500, type=int,
                        help='Number of steps to run validation.')
    parser.add_argument('--logger_name', default='./runs/runX/log',
                        help='Path to save Tensorboard log.')
    parser.add_argument('--model_name', default='./runs/runX/checkpoint',
                        help='Path to save the model.')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--max_violation', action='store_true',
                        help='Use max instead of sum in the rank loss.')
    parser.add_argument('--img_dim', default=2048, type=int,
                        help='Dimensionality of the image embedding.')
    parser.add_argument('--no_imgnorm', action='store_true',
                        help='Do not normalize the image embeddings.')
    parser.add_argument('--no_txtnorm', action='store_true',
                        help='Do not normalize the text embeddings.')
    parser.add_argument(
        '--raw_feature_norm',
        default="clipped_l2norm",
        help='clipped_l2norm|l2norm|clipped_l1norm|l1norm|no_norm|softmax')
    parser.add_argument('--agg_func', default="LogSumExp",
                        help='LogSumExp|Mean|Max|Sum')
    parser.add_argument('--cross_attn', default="t2i", help='t2i|i2t')
    parser.add_argument('--precomp_enc_type', default="basic",
                        help='basic|weight_norm')
    parser.add_argument('--bi_gru', action='store_true',
                        help='Use bidirectional GRU.')
    parser.add_argument('--lambda_lse', default=6., type=float,
                        help='LogSumExp temp.')
    parser.add_argument('--lambda_softmax', default=9., type=float,
                        help='Attention softmax temperature.')
    opt = parser.parse_args()
    print(opt)

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    tb_logger.configure(opt.logger_name, flush_secs=5)

    # Load Vocabulary Wrapper
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data.get_loaders(opt.data_name, vocab,
                                                opt.batch_size, opt.workers,
                                                opt)

    # Construct the model
    model = SCAN(opt)

    # Defaults for a fresh run; overwritten below when a checkpoint is loaded.
    start_epoch = 0
    best_rsum = 0

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            # Checkpoints written below store 'epoch' as epoch + 1, i.e. the
            # next epoch to run, so it can be used as the loop start directly.
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model.  The restored start_epoch/best_rsum are honoured here;
    # previously best_rsum was reset to 0 and the loop always began at 0, so
    # a resumed run repeated finished epochs and could replace the best
    # checkpoint with a worse one.
    for epoch in range(start_epoch, opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)

        adjust_learning_rate(opt, model.optimizer, epoch)

        # train for one epoch
        train(opt, train_loader, model, epoch, val_loader)

        # evaluate on validation set
        rsum = validate(opt, val_loader, model)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        # makedirs also creates missing parents of the nested default path.
        os.makedirs(opt.model_name, exist_ok=True)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'best_rsum': best_rsum,
                'opt': opt,
                'Eiters': model.Eiters,
            },
            is_best,
            filename='checkpoint_{}.pth.tar'.format(epoch),
            prefix=opt.model_name + '/')
예제 #5
0
def start_experiment(opt, seed):
    """Run one seeded SCAN training experiment and return its best rsum.

    Seeds the torch, NumPy and ``random`` generators with ``seed``, builds
    the vocabulary, data loaders and model from ``opt``, optionally restores
    a checkpoint from ``opt.resume``, then trains up to ``opt.num_epochs``.
    A checkpoint is written only for an epoch that improves the validation
    score or for the final epoch.

    Args:
        opt: options object; ``opt.vocab_size`` is set here as a side effect.
        seed: value used to seed all three random number generators.

    Returns:
        The highest validation score seen (including a restored one).
    """
    for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
        seed_rng(seed)

    print("Let's use", torch.cuda.device_count(), "GPUs!")
    print("Number threads:", torch.get_num_threads())

    # Vocabulary wrapper: maps between word ids and words.
    vocab_file = "{}/{}/{}_vocab_{}.json".format(
        opt.vocab_path, opt.clothing, opt.data_name, opt.version)
    vocab = deserialize_vocab(vocab_file)
    opt.vocab_size = len(vocab)

    loaders = data_ken.get_loaders(opt.data_name, vocab, opt.batch_size,
                                   opt.workers, opt)
    train_loader, val_loader = loaders

    model = SCAN(opt)

    # Keep a record of the hyperparameters next to the logs.
    save_hyperparameters(opt.logger_name, opt)

    best_rsum, first_epoch = 0, 0
    if opt.resume:
        if not os.path.isfile(opt.resume):
            print("=> no checkpoint found at '{}'".format(opt.resume))
        else:
            print("=> loading checkpoint '{}'".format(opt.resume))
            state = torch.load(opt.resume)
            # Checkpoints store the epoch that was completed; continue after it.
            first_epoch = state['epoch'] + 1
            best_rsum = state['best_rsum']
            model.load_state_dict(state['model'])
            # Eiters lets the logs continue from the earlier training run.
            model.Eiters = state['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, first_epoch, best_rsum))
            validate(opt, val_loader, model)

    for epoch in range(first_epoch, opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)
        adjust_learning_rate(opt, model.optimizer, epoch)

        # One pass over the training data, then score on the validation set.
        train(opt, train_loader, model, epoch, val_loader)
        rsum = validate(opt, val_loader, model)

        # Track the best R@ sum seen so far.
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        if not os.path.exists(opt.model_name):
            os.mkdir(opt.model_name)

        last_epoch = epoch == (opt.num_epochs - 1)

        # Nothing to save unless this is the best or the final epoch.
        if not (is_best or last_epoch):
            continue

        snapshot = {
            'epoch': epoch,
            'model': model.state_dict(),
            'best_rsum': best_rsum,
            'opt': opt,
            'Eiters': model.Eiters,
        }
        save_checkpoint(snapshot,
                        is_best,
                        last_epoch,
                        filename='checkpoint_{}.pth.tar'.format(epoch),
                        prefix=opt.model_name + '/')
    return best_rsum