예제 #1
0
def train(train_loader, val_loader, epochnum, save_path='.', save_freq=None):
    """Train an ``Encoder`` classifier on the GPU and save it to disk.

    Each epoch runs one SGD pass over ``train_loader``, evaluates on
    ``val_loader``, prints per-sample loss and accuracy for both, and then
    multiplies the learning rate of the first parameter group by
    ``exp(-0.4)``. The whole network object is written with ``torch.save``.

    Args:
        train_loader: Iterable of ``(inputs, labels)`` training batches; must
            support ``len()``.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        epochnum: Number of epochs to run.
        save_path: Directory in which the network files are written.
        save_freq: If truthy, also save ``epoch_<n>`` every ``save_freq``
            epochs in addition to the final ``trained_net``.
    """
    iter_size = len(train_loader)
    net = Encoder()
    net.cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(net.parameters(),
                          lr=0.01,
                          momentum=0.9,
                          weight_decay=2e-4)

    for epoch in range(epochnum):
        print('epoch : {}'.format(epoch))

        # ---- training pass ----
        net.train()
        train_loss = 0.0
        train_correct = 0
        total = 0
        for i, data in enumerate(train_loader):
            sys.stdout.write('iter : {} / {}\r'.format(i, iter_size))
            sys.stdout.flush()
            inputs, labels = data
            inputs, labels = inputs.cuda(), labels.cuda()
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # The criterion returns the batch mean; weight it by the batch
            # size so that dividing by `total` yields a per-sample average.
            train_loss += loss.item() * batch_size
            pred = torch.max(outputs.detach(), 1)[1]
            # .item() keeps the counters as Python numbers instead of
            # integer tensors (whose division truncates or errors).
            train_correct += (pred == labels).sum().item()
            total += batch_size
        # Erase the progress line.
        sys.stdout.write(' ' * 20 + '\r')
        sys.stdout.flush()

        seen = max(total, 1)  # avoid dividing by zero on an empty loader
        print('train_loss:{}, train_acc:{:.2%}'.format(
            train_loss / seen, float(train_correct) / seen))

        # ---- validation pass ----
        val_loss = 0.0
        val_correct = 0
        total = 0
        net.eval()
        # No gradients are needed for evaluation; skipping the autograd graph
        # saves memory and time.
        with torch.no_grad():
            for data in val_loader:
                inputs, labels = data
                inputs, labels = inputs.cuda(), labels.cuda()
                outputs = net(inputs)
                pred = torch.max(outputs, 1)[1]
                batch_size = labels.size(0)
                total += batch_size
                loss = criterion(outputs, labels)
                val_loss += loss.item() * batch_size
                val_correct += (pred == labels).sum().item()

        seen = max(total, 1)
        print('val_loss:{}, val_acc:{:.2%}'.format(
            val_loss / seen, float(val_correct) / seen))

        # Exponential learning-rate decay, applied to the first group only.
        optimizer.param_groups[0]['lr'] *= np.exp(-0.4)
        if save_freq and epoch % save_freq == save_freq - 1:
            net_name = os.path.join(save_path, 'epoch_{}'.format(epoch))
            torch.save(net, net_name)
    torch.save(net, os.path.join(save_path, 'trained_net'))
예제 #2
0
class PretrainingTrainer:
    """Wire a ``Preprocess`` pipeline, an ``Encoder`` and an Adam optimizer.

    The components start unset and are populated by the ``setup_*`` methods;
    ``run_pretraining`` calls them in order and then ``train_model``.
    """

    def __init__(self):
        # Nothing is built eagerly; see the setup_* methods.
        self.preprocessor = self.model = self.optimizer = None

    def setup_preprocessed_data(self):
        """Create the ``Preprocess`` object, store it, and run its ``setup()``."""
        preprocessor = Preprocess()
        self.preprocessor = preprocessor
        preprocessor.setup()

    def setup_model(self):
        """Instantiate the ``Encoder`` and move it to the GPU if ``con.CUDA``."""
        encoder = Encoder()
        self.model = encoder.cuda() if con.CUDA else encoder

    def setup_scheduler_optimizer(self):
        """Attach an Adam optimizer (lr 0.001, no weight decay) to the model."""
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=0.001,
                                    weight_decay=0)

    def train_model(self):
        """Run the model on the first training batch and print the output shape.

        Only the first element of each 4-tuple batch is fed to the model, and
        the loop stops after a single batch. No loss, backward pass or
        optimizer step is performed here.
        """
        batches = self.preprocessor.train_loaders
        self.model.train()

        for hrl_src, _lrl_src, _hrl_att, _lrl_att in batches:
            print(self.model(hrl_src).shape)
            # TODO: the optimisation step (zero_grad / backward / gradient
            # clipping / optimizer.step) and accuracy tracking are not
            # implemented yet; only one batch is inspected.
            break

    def run_pretraining(self):
        """Execute the full pipeline: data, model, optimizer, then training."""
        pipeline = (
            self.setup_preprocessed_data,
            self.setup_model,
            self.setup_scheduler_optimizer,
            self.train_model,
        )
        for stage in pipeline:
            stage()
예제 #3
0
def main(args):
    """Train an image-captioning encoder/decoder with score-based early stop.

    Every epoch trains on the training loader, generates a caption for each
    validation image, scores the captions with ``coco_metrics``, and saves
    both a per-epoch checkpoint and (on improvement) the best checkpoint.
    The learning rate is shrunk every 4 epochs without improvement, and
    training stops once more than 10 epochs have passed since the best one.

    Args:
        args: Namespace providing the paths (``root_folder``, ``loss_path``,
            ``mertics_path``, ``epoch_model_path``, ``best_model_path``,
            ``generated_captions_folder_path``, ``sentences_show_path``,
            ``vocab_path``, ``train_image_dir``, ``train_caption_path``,
            ``val_image_dir``, ``val_caption_path``) and hyper-parameters
            (``batch_size``, ``num_workers``, ``hidden_dim``,
            ``embedding_dim``, ``max_seq_length``, ``fine_tuning``,
            ``fine_tuning_lr``, ``resume``, ``shrink_factor``) read below.
    """

    # ==============================
    # Create some folders or files for saving
    # ==============================

    # makedirs(exist_ok=True) also creates missing parent folders and does
    # not fail if the folder appears between a check and the creation.
    os.makedirs(args.root_folder, exist_ok=True)

    loss_path = args.loss_path
    mertics_path = args.mertics_path
    epoch_model_path = args.epoch_model_path
    best_model_path = args.best_model_path
    generated_captions_path = args.generated_captions_folder_path
    sentences_show_path = args.sentences_show_path

    # Transform the format of images
    # This function in utils.general_tools.py
    train_transform = get_train_transform()
    val_transform = get_val_trainsform()

    # Load vocabulary
    # NOTE: pickle can execute arbitrary code on load; only point
    # `vocab_path` at a trusted file.
    print("*** Load Vocabulary ***")
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Create data sets
    # This function in data_load.py
    train_data = train_load(root=args.train_image_dir,
                            json=args.train_caption_path,
                            vocab=vocab,
                            transform=train_transform,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers)

    val_data = val_load(root=args.val_image_dir,
                        json=args.val_caption_path,
                        transform=val_transform,
                        batch_size=1,
                        shuffle=False,
                        num_workers=args.num_workers)

    # Build model
    encoder = Encoder(args.hidden_dim, args.fine_tuning).to(device)
    decoder = Decoder(args.embedding_dim, args.hidden_dim, vocab, len(vocab),
                      args.max_seq_length).to(device)

    # Select loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # When fine-tuning, the encoder is optimised together with the decoder;
    # otherwise only the decoder parameters are updated.
    # NOTE(review): both branches use `args.fine_tuning_lr`; confirm whether
    # the decoder-only branch was meant to use a different learning rate.
    if args.fine_tuning == True:
        params = list(decoder.parameters()) + list(encoder.parameters())
        optimizer = torch.optim.Adam(params, lr=args.fine_tuning_lr)
    else:
        params = decoder.parameters()
        optimizer = torch.optim.Adam(params, lr=args.fine_tuning_lr)

    # Load pretrained model
    if args.resume == True:
        checkpoint = torch.load(best_model_path)
        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])
        # The optimizer state is only restored in the decoder-only setting.
        if args.fine_tuning == False:
            optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch'] + 1
        best_score = checkpoint['best_score']
        best_epoch = checkpoint['best_epoch']

    # New epoch and score
    else:
        start_epoch = 1
        best_score = 0
        best_epoch = 0

    for epoch in range(start_epoch, 10000):

        print("-" * 20)
        print("epoch:{}".format(epoch))

        # Shrink the learning rate whenever the number of epochs since the
        # best epoch is a positive multiple of 4.
        if (epoch - best_epoch) > 0 and (epoch - best_epoch) % 4 == 0:
            # This function in utils.general_tools.py
            adjust_lr(optimizer, args.shrink_factor)
        # Early stopping: more than 10 epochs without a new best score.
        if (epoch - best_epoch) > 10:
            # The message must be printed before leaving the loop; placed
            # after `break` it could never run.
            print("*** Training complete ***")
            break

        # =============
        # Training
        # =============

        print(" *** Training ***")
        decoder.train()
        encoder.train()
        total_step = len(train_data)
        epoch_loss = 0
        for (images, captions, lengths, img_ids) in tqdm(train_data):
            images = images.to(device)
            captions = captions.to(device)
            # Why do lengths cut 1 and the first dimension of captions from 1
            # Because we need to ignore the begining symbol <start>
            lengths = list(np.array(lengths) - 1)

            targets = pack_padded_sequence(captions[:, 1:],
                                           lengths,
                                           batch_first=True)[0]
            features = encoder(images)
            predictions = decoder(features, captions, lengths)
            predictions = pack_padded_sequence(predictions,
                                               lengths,
                                               batch_first=True)[0]

            loss = criterion(predictions, targets)
            epoch_loss += loss.item()
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()

        # Save loss information
        # This function in utils.save_tools.py
        save_loss(round(epoch_loss / total_step, 3), epoch, loss_path)

        # =============
        # Evaluating
        # =============

        print("*** Evaluating ***")
        encoder.eval()
        decoder.eval()
        generated_captions = []
        # Caption generation needs no gradients; disabling autograd avoids
        # building a graph for every validation image.
        with torch.no_grad():
            for image, img_id in tqdm(val_data):

                image = image.to(device)
                img_id = img_id[0]

                features = encoder(image)
                sentence = decoder.generate(features)
                sentence = ' '.join(sentence)
                item = {'image_id': int(img_id), 'caption': sentence}
                generated_captions.append(item)
                # NOTE(review): `j` is never read. The call is kept only
                # because it advances the global `random` state; remove it
                # if nothing relies on that.
                j = random.randint(1, 100)

        print('*** Computing metrics ***')

        # Save current generated captions
        # This function in utils.save_tools.py

        captions_json_path = save_generated_captions(generated_captions, epoch,
                                                     generated_captions_path,
                                                     args.fine_tuning)

        # Compute score of metrics
        # This function in utils.general_tools.py
        results = coco_metrics(args.val_caption_path, captions_json_path,
                               epoch, sentences_show_path)

        # Save metrics results
        # This function in utils.save_tools.py
        epoch_score = save_metrics(results, epoch, mertics_path)

        # Update the best score
        if best_score < epoch_score:

            best_score = epoch_score
            best_epoch = epoch

            save_best_model(encoder, decoder, optimizer, epoch, best_score,
                            best_epoch, best_model_path)

        print("*** Best score:{} Best epoch:{} ***".format(
            best_score, best_epoch))
        # Save every epoch model
        save_epoch_model(encoder, decoder, optimizer, epoch, best_score,
                         best_epoch, epoch_model_path, args.fine_tuning)
예제 #4
0
def main():
    """Train a CNN encoder / LSTM decoder captioning model on COCO.

    All settings are hard-coded below. The decoder and the encoder's final
    ``resnet.fc`` layer are optimised with Adam; decoder and encoder state
    dicts are written to ``model_path`` every ``save_step`` iterations.
    """
    ##### arguments #####
    image_dir = './data/resized2014/'
    caption_path = './data/annotations/captions_train2014.json'
    vocab_path = './data/vocab.pkl'
    model_path = './model'
    crop_size = 224
    batch_size = 128
    num_workers = 4
    learning_rate = 0.001

    # Decoder
    embed_size = 512
    hidden_size = 512
    num_layers = 3  # number of lstm layers
    num_epochs = 10
    start_epoch = 0  # offset used only when naming logs and checkpoints
    save_step = 3000

    # Create model directory
    os.makedirs(model_path, exist_ok=True)

    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load vocabulary wrapper
    # NOTE: pickle can execute arbitrary code on load; only use a trusted
    # vocabulary file.
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loader
    coco = CocoDataset(image_dir, caption_path, vocab, transform)
    dataLoader = torch.utils.data.DataLoader(coco,
                                             batch_size,
                                             shuffle=True,
                                             num_workers=num_workers,
                                             collate_fn=coco_batch)

    # Declare the encoder decoder
    encoder = Encoder(embed_size=embed_size).to(device)
    decoder = Decoder(embed_size=embed_size,
                      hidden_size=hidden_size,
                      vocab_size=len(vocab),
                      num_layers=num_layers).to(device)

    encoder.train()
    decoder.train()
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    # For encoder only train the last fc layer
    params = list(decoder.parameters()) + list(encoder.resnet.fc.parameters())
    optimizer = torch.optim.Adam(params, lr=learning_rate)

    # Train the models
    total_step = len(dataLoader)
    for epoch in range(num_epochs):
        for i, (images, captions, lengths) in enumerate(dataLoader):
            # Set mini-batch dataset. Use the same `device` the models were
            # moved to, rather than unconditionally calling .cuda().
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths,
                                           batch_first=True)[0]

            # Forward, backward and optimize
            features = encoder(images)
            outputs = decoder(features, captions, lengths)
            loss = criterion(outputs, targets)
            decoder.zero_grad()
            encoder.zero_grad()

            # Once a parameter's Adam step counter reaches 1024 it is reset
            # to 1000.
            # NOTE(review): presumably this keeps Adam's bias correction from
            # fading out; confirm the intent and that it is compatible
            # with the installed PyTorch's optimizer state format.
            for group in optimizer.param_groups:
                for p in group['params']:
                    state = optimizer.state[p]
                    if ('step' in state and state['step'] >= 1024):
                        state['step'] = 1000

            # NOTE(review): retain_graph=True keeps the graph buffers alive
            # after backward; it is left in place because it is only
            # removable if the decoder carries no graph state across
            # iterations, which is not visible here.
            loss.backward(retain_graph=True)
            optimizer.step()

            # Print log info
            if i % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch + 1 + start_epoch, num_epochs + start_epoch, i,
                    total_step, loss.item()))

            # Save the model checkpoints
            if (i + 1) % save_step == 0:
                torch.save(
                    decoder.state_dict(),
                    os.path.join(
                        model_path,
                        'decoder-{}-{}.ckpt'.format(epoch + 1 + start_epoch,
                                                    i + 1)))
                torch.save(
                    encoder.state_dict(),
                    os.path.join(
                        model_path,
                        'encoder-{}-{}.ckpt'.format(epoch + 1 + start_epoch,
                                                    i + 1)))

        # Loss of the last batch of the epoch.
        print('epoch ', epoch + 1, 'loss: ', loss.item())
예제 #5
0
print("accuracy in intial train of DCD:{}".format(acc))

# Part 3: adversarial training of the DCD against g and h.
print("part 3 : adversarial training of g&h and DCD")
dcd_test_acc = []  # lists for storing results
cls_s_test_acc = []
cls_t_test_acc = []

for epoch in range(num_ep_train):
    #optimizer_g = torch.optim.Adam(net_g.parameters(), lr = 0.001)
    # NOTE(review): both optimizers are constructed anew at the start of
    # every epoch, so any optimizer state accumulated in the previous epoch
    # is discarded — confirm this is intended.
    optimizer_g_h = torch.optim.Adam(list(net_g.parameters()) +
                                     list(net_h.parameters()),
                                     lr=0.001)
    optimizer_DCD = torch.optim.Adam(net_DCD.parameters(), lr=0.001)
    # --------- Hold DCD fixed and train g, h ---------------
    net_g.train().to(device)
    net_h.train().to(device)
    net_DCD.to(device)
    optimizer_g_h.zero_grad()
    # Load G2 and G4
    G2_G4_alterd, label = G2_G4_loader(s_trainset,
                                       t_trainset,
                                       net_g,
                                       device=device)
    G2_G4_alterd, label = G2_G4_alterd.to(device), label.to(device)
    # Compute the loss of being discriminated by the DCD
    dcd_pred = net_DCD(G2_G4_alterd)
    loss_dcd = loss_func(dcd_pred, label)
    # Compute the classification loss
    # Classification loss on the source domain
    # Draw `adv_gh_datanum` source samples at random, with replacement.
    s_sample = random.choices(s_trainset, k=adv_gh_datanum)
예제 #6
0
파일: train.py 프로젝트: v-juma1/textent
def _init_embedding_matrix(vocab_size, dim_size):
    """Return a float32 ``(vocab_size + 1, dim_size)`` embedding matrix.

    Row 0 is all zeros; the remaining rows are drawn from U(-0.05, 0.05).
    """
    matrix = np.random.uniform(low=-0.05,
                               high=0.05,
                               size=(vocab_size, dim_size))
    return np.vstack([np.zeros(dim_size), matrix]).astype('float32')


def train(description_db, entity_db, word_vocab, entity_vocab,
          target_entity_vocab, out_file, embeddings, dim_size, batch_size,
          negative, epoch, optimizer, max_text_len, max_entity_len, pool_size,
          seed, save, **model_params):
    """Train an ``Encoder`` on batches from ``generate_data`` on the GPU.

    Embedding matrices for words, entities and target entities are randomly
    initialised, optionally overwritten with vectors from ``embeddings``,
    and handed to the model. Vocabularies and model parameters are dumped to
    ``out_file + '.pkl'``; running loss and accuracy are logged every 1000
    batches.

    Args:
        description_db: Forwarded to ``generate_data``.
        entity_db: Not used by the code visible in this function.
        word_vocab, entity_vocab, target_entity_vocab: Vocabularies exposing
            ``size``, iteration, ``get_index`` and ``serialize``.
        out_file: Path prefix for the ``.pkl`` metadata and ``.bin`` weights.
        embeddings: Iterable of objects with ``get_word_vector`` and
            ``get_entity_vector``; a ``None`` result leaves the random row.
        dim_size: Embedding dimensionality.
        batch_size, negative, max_text_len, max_entity_len, pool_size:
            Forwarded to ``generate_data``.
        epoch: Number of epochs to run.
        optimizer: Name of the optimizer class looked up on ``optim``.
        seed: Seed for ``random``, NumPy and PyTorch (CPU and CUDA).
        save: If falsy or containing ``0``, the initial weights are written
            to ``out_file + '_epoch0.bin'``.
        **model_params: Extra keyword arguments for ``Encoder``; also stored
            in the ``.pkl`` file together with ``dim_size``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # The draw order (word, entity, target entity) is kept so that a given
    # seed yields the same initial matrices as before.
    word_matrix = _init_embedding_matrix(word_vocab.size, dim_size)
    entity_matrix = _init_embedding_matrix(entity_vocab.size, dim_size)
    target_entity_matrix = _init_embedding_matrix(target_entity_vocab.size,
                                                  dim_size)

    # Overwrite random rows with pretrained vectors where available; later
    # embeddings take precedence over earlier ones.
    for embedding in embeddings:
        for word in word_vocab:
            vec = embedding.get_word_vector(word)
            if vec is not None:
                word_matrix[word_vocab.get_index(word)] = vec

        for title in entity_vocab:
            vec = embedding.get_entity_vector(title)
            if vec is not None:
                entity_matrix[entity_vocab.get_index(title)] = vec

        for title in target_entity_vocab:
            vec = embedding.get_entity_vector(title)
            if vec is not None:
                target_entity_matrix[target_entity_vocab.get_index(
                    title)] = vec

    # Candidate indices for negative sampling: every row except row 0.
    entity_negatives = np.arange(1, target_entity_matrix.shape[0])

    model_params.update(dict(dim_size=dim_size))
    model = Encoder(word_embedding=word_matrix,
                    entity_embedding=entity_matrix,
                    target_entity_embedding=target_entity_matrix,
                    word_vocab=word_vocab,
                    entity_vocab=entity_vocab,
                    target_entity_vocab=target_entity_vocab,
                    **model_params)

    # Drop the local references to the NumPy matrices now that the model
    # has been constructed from them.
    del word_matrix
    del entity_matrix
    del target_entity_matrix

    model = model.cuda()

    model.train()
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer_ins = getattr(optim, optimizer)(parameters)

    n_correct = 0
    n_total = 0
    cur_correct = 0
    cur_total = 0
    cur_loss = 0.0

    batch_idx = 0

    joblib.dump(
        dict(model_params=model_params,
             word_vocab=word_vocab.serialize(),
             entity_vocab=entity_vocab.serialize(),
             target_entity_vocab=target_entity_vocab.serialize()),
        out_file + '.pkl')

    if not save or 0 in save:
        state_dict = model.state_dict()
        torch.save(state_dict, out_file + '_epoch0.bin')

    for n_epoch in range(1, epoch + 1):
        logger.info('Epoch: %d', n_epoch)

        # Batch numbering continues from the last index of the previous
        # epoch instead of restarting at 0.
        for (batch_idx, (args, target)) in enumerate(
                generate_data(description_db, word_vocab, entity_vocab,
                              target_entity_vocab, entity_negatives,
                              batch_size, negative, max_text_len,
                              max_entity_len, pool_size), batch_idx):
            # `async` became a reserved keyword in Python 3.7, which makes
            # `cuda(async=True)` a SyntaxError; `non_blocking` is the
            # replacement argument.
            args = tuple(o.cuda(non_blocking=True) for o in args)
            target = target.cuda()

            optimizer_ins.zero_grad()
            output = model(args)
            loss = F.cross_entropy(output, target)
            loss.backward()

            optimizer_ins.step()

            # Accumulate plain Python numbers: indexing a 0-dim loss tensor
            # with [0] raises on current PyTorch, and summing tensors would
            # keep them on the GPU.
            predicted = torch.max(output, 1)[1].view(target.size())
            cur_correct += (predicted == target).sum().item()
            cur_total += len(target)
            cur_loss += loss.item()
            if batch_idx != 0 and batch_idx % 1000 == 0:
                n_correct += cur_correct
                n_total += cur_total
                logger.info(
                    'Processed %d batches (epoch: %d, loss: %.4f acc: %.4f total acc: %.4f)',
                    batch_idx, n_epoch, cur_loss / cur_total,
                    100. * cur_correct / cur_total,
                    100. * n_correct / n_total)
                cur_correct = 0
                cur_total = 0
                cur_loss = 0.0
예제 #7
0
파일: main.py 프로젝트: mandiehyewon/AI502
	# --- Discriminator update ---
	# Fragment of a training-loop body: `images`, `step`, the networks and
	# the optimizers are defined outside the visible lines.
	## true prior is random normal (randn)
	## this is constraining the Z-projection to be normal!
	Enc.eval()
	# "Real" latent codes: standard-normal samples scaled by 5, one per image.
	z_real = Variable(torch.randn(images.size()[0], latentspace_dim) * 5.).cuda()
	D_real = Discrim(z_real)

	# "Fake" latent codes: the encoder's output for the current images.
	z_fake = Enc(images)
	D_fake = Discrim(z_fake)

	# Negative mean of log D(real) + log(1 - D(fake)).
	D_loss = -torch.mean(torch.log(D_real) + torch.log(1 - D_fake))

	D_loss.backward()
	# NOTE(review): the step is taken with `optim_Dec` although the loss is
	# computed from `Discrim` — confirm which parameters this optimizer holds.
	# NOTE(review): no zero_grad() call is visible in this fragment; verify
	# gradients are cleared before each backward pass.
	optim_Dec.step()

	# Generator
	# The encoder is updated so that its codes score highly under Discrim.
	Enc.train()
	z_fake = Enc(images)
	D_fake = Discrim(z_fake)
	
	# Negative mean of log D(fake).
	G_loss = -torch.mean(torch.log(D_fake))

	G_loss.backward()
	optim_Enc_gen.step()   

	
	# Every 100 steps; the body of this branch continues past the visible lines.
	if (step+1) % 100 == 0:
		# print ('Step [%d/%d], Loss: %.4f, Acc: %.2f' 
		#		%(step+1, total_step, loss.data[0], accuracy.data[0]))

		#============ TensorBoard logging ============#
		# (1) Log the scalar values
예제 #8
0
class Image_Captioning:
    """Command-line driven trainer for an encoder/decoder captioning model.

    ``__init__`` parses the command line, loads the pickled vocabulary,
    builds the data loader and constructs the ``Encoder`` and ``Decoder``.
    ``train`` optimises the decoder plus the encoder's ``fc`` and ``BN``
    sub-modules with Adam and saves both state dicts when training ends.
    """

    def __init__(self):
        parser = argparse.ArgumentParser(description='Image Captioning')
        parser.add_argument('--root',
                            default='../../../cocodataset/',
                            type=str)
        parser.add_argument('--crop_size', default=224, type=int)
        parser.add_argument('--epochs', default=100, type=int)
        parser.add_argument('--lr', default=1e-4, type=float)
        # type=int is required: without it a value given on the command
        # line arrives as a string and is passed on as the batch size.
        parser.add_argument('--batch_size', default=128, type=int, help='')
        parser.add_argument('--num_workers', default=4, type=int)
        parser.add_argument('--embed_dim', default=256, type=int)
        parser.add_argument('--hidden_size', default=512, type=int)
        parser.add_argument('--num_layers', default=1, type=int)
        parser.add_argument('--model_path', default='./model/', type=str)
        parser.add_argument('--vocab_path', default='./vocab/', type=str)
        parser.add_argument('--save_step', default=1000, type=int)

        self.args = parser.parse_args()
        # Multi-GPU training is currently switched off; the detection code
        # below is intentionally left disabled.
        self.Multi_GPU = False

        # if torch.cuda.device_count() > 1:
        #     print('Multi GPU Activate!')
        #     print('Using GPU :', int(torch.cuda.device_count()))
        #     self.Multi_GPU = True

        os.makedirs(self.args.model_path, exist_ok=True)

        transform = transforms.Compose([
            transforms.RandomCrop(self.args.crop_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

        # os.path.join works whether or not --vocab_path ends with a slash.
        # NOTE: pickle can execute arbitrary code on load; only use a
        # trusted vocabulary file.
        vocab_file = os.path.join(self.args.vocab_path, 'vocab.pickle')
        with open(vocab_file, 'rb') as f:
            self.vocab = pickle.load(f)

        self.DataLoader = get_dataloader(root=self.args.root,
                                         transform=transform,
                                         shuffle=True,
                                         batch_size=self.args.batch_size,
                                         num_workers=self.args.num_workers,
                                         vocab=self.vocab)

        self.Encoder = Encoder(embed_dim=self.args.embed_dim)
        self.Decoder = Decoder(embed_dim=self.args.embed_dim,
                               hidden_size=self.args.hidden_size,
                               vocab_size=len(self.vocab),
                               num_layers=self.args.num_layers)

    def train(self):
        """Run ``args.epochs`` epochs of training on the GPU and save weights.

        The final encoder and decoder state dicts are written to
        ``args.model_path`` as ``Encoder-<epochs>.ckpt`` and
        ``Decoder-<epochs>.ckpt``.
        """
        if self.Multi_GPU:
            # NOTE(review): this path is unreachable while Multi_GPU is
            # hard-coded to False and has not been exercised.
            self.Encoder = torch.nn.DataParallel(self.Encoder)
            self.Decoder = torch.nn.DataParallel(self.Decoder)
            parameters = list(self.Encoder.module.fc.parameters()) + list(
                self.Encoder.module.BN.parameters()) + list(
                    self.Decoder.parameters())
        else:
            parameters = list(self.Encoder.fc.parameters()) + list(
                self.Encoder.BN.parameters()) + list(self.Decoder.parameters())

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(parameters, lr=self.args.lr)

        self.Encoder.cuda()
        self.Decoder.cuda()

        self.Encoder.train()
        self.Decoder.train()

        print('-' * 100)
        print('Now Training')
        print('-' * 100)

        num_batches = len(self.DataLoader)
        for epoch in range(self.args.epochs):
            total_loss = 0
            for batch_idx, (image, captions,
                            lengths) in enumerate(self.DataLoader):
                optimizer.zero_grad()
                image, captions = image.cuda(), captions.cuda()

                targets = pack_padded_sequence(captions,
                                               lengths,
                                               batch_first=True)[0]

                # The modules are already wrapped in DataParallel above when
                # Multi_GPU is set, so the forward pass is an ordinary call
                # in both modes. Building `DataParallel(module, tensor)`
                # here would create a new wrapper (with the tensor taken as
                # `device_ids`) instead of running the model.
                img_features = self.Encoder(image)
                outputs = self.Decoder(img_features, captions, lengths)

                loss = criterion(outputs, targets)
                total_loss += loss.item()

                loss.backward()
                optimizer.step()

                if batch_idx % 30 == 0:
                    print('Epoch : {}, Step : [{}/{}], Step Loss : {:.4f}'.
                          format(epoch, batch_idx, num_batches, loss.item()))

            print('Epoch : [{}/{}], Total loss : {:.4f}'.format(
                epoch, self.args.epochs, total_loss / num_batches))

        print('Now saving the models')
        torch.save(
            self.Encoder.state_dict(),
            os.path.join(self.args.model_path,
                         'Encoder-{}.ckpt'.format(self.args.epochs)))
        torch.save(
            self.Decoder.state_dict(),
            os.path.join(self.args.model_path,
                         'Decoder-{}.ckpt'.format(self.args.epochs)))
예제 #9
0
def main():
    """Train an ``Encoder`` on MNIST so its features follow input rotations.

    For each batch the images are rotated with ``rotate_tensor``; the
    encoder's features of the original images are transformed by
    ``feature_transformer`` using the rotation angles and pulled towards the
    features of the rotated images (summed squared pairwise distance).
    Train, test and rotation losses are recorded every ``--store-interval``
    batches, saved under ``./output`` and plotted at the end.
    """
    import os  # local import: only needed to prepare the output directory

    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument(
        '--test-batch-size',
        type=int,
        default=10000,
        metavar='N',
        help='input batch size for reconstruction testing (default: 10,000)')
    # The help text previously claimed a default of 100.
    parser.add_argument('--epochs',
                        type=int,
                        default=20,
                        metavar='N',
                        help='number of epochs to train (default: 20)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.001)')
    # NOTE(review): --momentum is parsed but not read in this function (the
    # optimizer below is Adam).
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument(
        '--store-interval',
        type=int,
        default=100,
        metavar='N',
        help='how many batches to wait before storing training loss')

    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Set up dataloaders
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data',
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)

    test_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data',
        train=False,
        transform=transforms.Compose([transforms.ToTensor()])),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              **kwargs)

    # Training set served in evaluation-sized batches, with default loader
    # options (no workers, no pinned memory).
    train_loader_eval = torch.utils.data.DataLoader(
        datasets.MNIST('../data',
                       train=True,
                       transform=transforms.Compose([transforms.ToTensor()])),
        batch_size=args.test_batch_size,
        shuffle=True)

    # Init model and optimizer
    model = Encoder(device).to(device)
    path = "./output"

    #Initialise weights
    model.apply(weights_init)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    rotation_test_loss = []
    train_loss = []
    test_loss = []

    # The distance module has no state that depends on the batch, so it is
    # built once instead of on every iteration.
    forb_distance = torch.nn.PairwiseDistance()

    # Where the magic happens
    for epoch in range(1, args.epochs + 1):
        # The dataset labels are not used; the targets are rotated images.
        for batch_idx, (data, _) in enumerate(train_loader):
            # Re-enter training mode every batch.
            # NOTE(review): presumably the evaluation helpers called below
            # switch the model to eval mode; confirm.
            model.train()
            # Reshape data
            targets, angles = rotate_tensor(data.numpy())
            targets = torch.from_numpy(targets).to(device)
            angles = torch.from_numpy(angles).to(device)
            angles = angles.view(angles.size(0), 1)

            # Forward passes
            data = data.to(device)
            optimizer.zero_grad()
            f_data = model(data)  # [N,2,1,1]
            f_targets = model(targets)  #[N,2,1,1]

            #Apply rotation matrix to f_data with feature transformer
            f_data_trasformed = feature_transformer(f_data, angles, device)

            #Define Loss: summed squared distance between the transformed
            #features and the features of the rotated images
            loss = (forb_distance(f_data_trasformed.view(-1, 2),
                                  f_targets.view(-1, 2))**2).sum()

            # Backprop
            loss.backward()
            optimizer.step()

            #Log progress
            if batch_idx % args.log_interval == 0:
                sys.stdout.write(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\r'.format(
                        epoch, batch_idx * len(data),
                        len(train_loader.dataset),
                        100. * batch_idx / len(train_loader), loss.item()))
                sys.stdout.flush()

            #Store training and test loss
            if batch_idx % args.store_interval == 0:
                #Train Loss
                train_loss.append(
                    evaluate_model(model, device, train_loader_eval))

                #Test Loss
                test_loss.append(evaluate_model(model, device, test_loader))

                #Rotation loss
                rotation_test_loss.append(
                    rotation_test(model, device, test_loader))

    #Save model
    save_model(args, model)
    #Save losses
    train_loss = np.array(train_loss)
    test_loss = np.array(test_loss)
    rotation_test_loss = np.array(rotation_test_loss)

    # np.save does not create missing directories.
    os.makedirs(path, exist_ok=True)
    np.save(os.path.join(path, 'training_loss'), train_loss)
    np.save(os.path.join(path, 'test_loss'), test_loss)
    np.save(os.path.join(path, 'rotation_test_loss'), rotation_test_loss)

    plot_learning_curve(args, train_loss, test_loss, rotation_test_loss)
예제 #10
0
def main(args):
    """Train the seq2seq translation model, then evaluate the best checkpoint.

    Derives per-run log/checkpoint/result paths from ``args.model_type`` and
    ``args.train_id``, writes the arguments to ``args.yaml``, builds the
    source (``en``) and target (``zh``) tokenizers and the data loaders, and
    trains for ``args.epochs`` epochs. A checkpoint is written whenever the
    validation corpus BLEU improves; afterwards ``eval`` is run on the test
    loader with that checkpoint file.

    Args:
        args: Parsed argument namespace. Attributes read here: ``model_type``,
            ``train_id``, ``LOG_DIR``, ``CHK_DIR``, ``RESULT_DIR``,
            ``TRAIN_VOCAB_EN``, ``TRAIN_VOCAB_ZH``, ``MAX_INPUT_LENGTH``,
            ``DATA_DIR``, ``batch_size``, ``MAX_VID_LENGTH``,
            ``wordembed_dim``, ``enc_hid_size``, ``dec_hid_size``,
            ``decoder_lr``, ``encoder_lr``, ``weight_decay`` and ``epochs``.

    Raises:
        ValueError: If ``args.model_type`` is not ``'s2s'``, the only model
            type this function knows how to build.
    """
    # Fail fast: previously an unknown model type only surfaced as a
    # NameError on `encoder` after directories, logs and loaders were created.
    if args.model_type != 's2s':
        raise ValueError("Unsupported model_type: {!r} (expected 's2s')".format(
            args.model_type))

    model_prefix = '{}_{}'.format(args.model_type, args.train_id)

    log_path = args.LOG_DIR + model_prefix + '/'
    checkpoint_path = args.CHK_DIR + model_prefix + '/'
    result_path = args.RESULT_DIR + model_prefix + '/'
    cp_file = checkpoint_path + "best_model.pth.tar"
    init_epoch = 0

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(checkpoint_path, exist_ok=True)

    ## set up the logger
    set_logger(os.path.join(log_path, 'train.log'))

    ## save argparse parameters
    with open(log_path + 'args.yaml', 'w') as f:
        for k, v in args.__dict__.items():
            f.write('{}: {}\n'.format(k, v))

    logging.info('Training model: {}'.format(model_prefix))

    ## set up vocab txt
    setup(args, clear=True)
    print(args.__dict__)

    # indicate src and tgt language
    src, tgt = 'en', 'zh'

    maps = {'en': args.TRAIN_VOCAB_EN, 'zh': args.TRAIN_VOCAB_ZH}
    vocab_src = read_vocab(maps[src])
    tok_src = Tokenizer(language=src,
                        vocab=vocab_src,
                        encoding_length=args.MAX_INPUT_LENGTH)
    vocab_tgt = read_vocab(maps[tgt])
    tok_tgt = Tokenizer(language=tgt,
                        vocab=vocab_tgt,
                        encoding_length=args.MAX_INPUT_LENGTH)
    logging.info('Vocab size src/tgt:{}/{}'.format(len(vocab_src),
                                                   len(vocab_tgt)))

    ## Setup the training, validation, and testing dataloaders
    train_loader, val_loader, test_loader = create_split_loaders(
        args.DATA_DIR, (tok_src, tok_tgt),
        args.batch_size,
        args.MAX_VID_LENGTH, (src, tgt),
        num_workers=4,
        pin_memory=True)
    logging.info('train/val/test size: {}/{}/{}'.format(
        len(train_loader), len(val_loader), len(test_loader)))

    ## init model (model_type was validated above, so this is always 's2s')
    encoder = Encoder(vocab_size=len(vocab_src),
                      embed_size=args.wordembed_dim,
                      hidden_size=args.enc_hid_size).cuda()
    decoder = Decoder(embed_size=args.wordembed_dim,
                      hidden_size=args.dec_hid_size,
                      vocab_size=len(vocab_tgt)).cuda()

    encoder.train()
    decoder.train()

    ## define loss
    criterion = nn.CrossEntropyLoss(ignore_index=padding_idx).cuda()
    ## init optimizer (one per sub-network so each can use its own lr)
    dec_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                                   decoder.parameters()),
                                     lr=args.decoder_lr,
                                     weight_decay=args.weight_decay)
    enc_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                                   encoder.parameters()),
                                     lr=args.encoder_lr,
                                     weight_decay=args.weight_decay)

    count_paras(encoder, decoder, logging)

    ## track loss during training
    total_train_loss, total_val_loss = [], []
    best_val_bleu, best_epoch = 0, 0

    ## init time
    zero_time = time.time()

    # Begin training procedure
    for epoch in range(init_epoch, args.epochs):
        ## train for one epoch
        start_time = time.time()
        train_loss = train(train_loader, encoder, decoder, criterion,
                           enc_optimizer, dec_optimizer, epoch)

        val_loss, sentbleu, corpbleu = validate(val_loader, encoder, decoder,
                                                criterion)
        end_time = time.time()

        epoch_time = end_time - start_time
        total_time = end_time - zero_time

        logging.info(
            'Total time used: %s Epoch %d time uesd: %s train loss: %.4f val loss: %.4f sentbleu: %.4f corpbleu: %.4f'
            % (str(datetime.timedelta(seconds=int(total_time))), epoch,
               str(datetime.timedelta(seconds=int(epoch_time))), train_loss,
               val_loss, sentbleu, corpbleu))

        # Only the best-scoring epoch (by validation corpus BLEU) is kept.
        if corpbleu > best_val_bleu:
            best_val_bleu = corpbleu
            save_checkpoint(
                {
                    'epoch': epoch,
                    'enc_state_dict': encoder.state_dict(),
                    'dec_state_dict': decoder.state_dict(),
                    'enc_optimizer': enc_optimizer.state_dict(),
                    'dec_optimizer': dec_optimizer.state_dict(),
                }, cp_file)
            best_epoch = epoch

        logging.info("Finished {0} epochs of training".format(epoch + 1))

        total_train_loss.append(train_loss)
        total_val_loss.append(val_loss)

    logging.info('Best corpus bleu score {:.4f} at epoch {}'.format(
        best_val_bleu, best_epoch))

    ### the best model is the last model saved in our implementation
    logging.info('************ Start eval... ************')
    # NOTE(review): `eval` is called with project arguments, so it is
    # presumably a project function shadowing the builtin - confirm the import.
    eval(test_loader, encoder, decoder, cp_file, tok_tgt, result_path)
예제 #11
0
def main():
    """Train an encoder/decoder translation model on the eng-fra pairs.

    Reads the language data, converts each sentence to padded index
    sequences, splits them 80/20 into train/validation sets, and trains the
    encoder and decoder jointly with Adam for ``EPOCHS`` epochs on the CPU
    using teacher forcing (the ground-truth token is fed as the next decoder
    input). Progress is printed every 100 batches and once per epoch.
    """
    # Imported locally so that both timing calls below resolve to the module,
    # regardless of how the file imports `time` at the top level (the
    # original mixed `time()` and `time.time()`, one of which always failed).
    import time

    input_lang, output_lang, _pairs, data1, data2 = read_langs("eng", "fra", True)
    input_tensor = [[input_lang.word2index[s] for s in es.split(' ')] for es in data1]
    target_tensor = [[output_lang.word2index[s] for s in es.split(' ')] for es in data2]
    max_length_inp, max_length_tar = max_length(input_tensor), max_length(target_tensor)

    input_tensor = [pad_sequences(x, max_length_inp) for x in input_tensor]
    target_tensor = [pad_sequences(x, max_length_tar) for x in target_tensor]
    print(len(target_tensor))

    input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(
        input_tensor, target_tensor, test_size=0.2)

    # Show length
    print(len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val))

    BUFFER_SIZE = len(input_tensor_train)
    BATCH_SIZE = 64
    N_BATCH = BUFFER_SIZE // BATCH_SIZE
    embedding_dim = 256
    units = 1024
    vocab_inp_size = len(input_lang.word2index)
    vocab_tar_size = len(output_lang.word2index)

    train_dataset = MyData(input_tensor_train, target_tensor_train)

    # drop_last=True keeps every batch at exactly BATCH_SIZE rows, which the
    # fixed-size <sos> decoder input below relies on.
    dataset = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                         drop_last=True,
                         shuffle=True)

    device = torch.device("cpu")

    encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)
    decoder = Decoder(vocab_tar_size, embedding_dim, units, units, BATCH_SIZE)

    encoder.to(device)
    decoder.to(device)

    criterion = nn.CrossEntropyLoss()

    optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()),
                           lr=0.001)

    EPOCHS = 10

    for epoch in range(EPOCHS):
        start = time.time()

        encoder.train()
        decoder.train()

        # Plain Python float: accumulating the loss tensor itself would keep
        # every batch's autograd graph alive for the whole epoch.
        total_loss = 0.0

        for (batch, (inp, targ, inp_len)) in enumerate(dataset):
            loss = 0

            xs, ys, lens = sort_batch(inp, targ, inp_len)
            enc_output, enc_hidden = encoder(xs.to(device), lens, device)
            dec_hidden = enc_hidden
            dec_input = torch.tensor([[output_lang.word2index['<sos>']]] * BATCH_SIZE)

            # Teacher forcing: step t is predicted from the true token t-1.
            for t in range(1, ys.size(1)):
                predictions, dec_hidden, _ = decoder(dec_input.to(device),
                                                     dec_hidden.to(device),
                                                     enc_output.to(device))
                loss += loss_function(criterion, ys[:, t].to(device), predictions.to(device))
                dec_input = ys[:, t].unsqueeze(1)

            # Per-token average, used for reporting only; the optimizer
            # still steps on the summed loss, as before.
            batch_loss = (loss / int(ys.size(1))).item()
            total_loss += batch_loss

            optimizer.zero_grad()

            loss.backward()

            ### UPDATE MODEL PARAMETERS
            optimizer.step()

            if batch % 100 == 0:
                print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                             batch,
                                                             batch_loss))

        ### TODO: Save checkpoint for model
        print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                            total_loss / N_BATCH))
        print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))
예제 #12
0
def main(frequency, batch_size, epoch_num, verbose, MODE):
    """Train the definition encoder and evaluate it on each new best epoch.

    Loads the data for ``frequency``, builds train/valid/test loaders and the
    per-word multi-hot feature tensors (sememes, lexnames, root-affixes) with
    their masks, then trains ``Encoder`` with Adam. After every epoch the
    model is scored on the validation set; whenever the validation top-10
    accuracy improves, the test set is evaluated and, after epoch 5, the test
    labels and predictions are dumped to JSON files.

    Args:
        frequency: Forwarded to ``load_data``.
        batch_size: Batch size for all three data loaders.
        epoch_num: Number of training epochs.
        verbose: Passed as ``tqdm``'s ``disable`` flag, so a truthy value
            hides the progress bars.
        MODE: Forwarded to the model as ``mode`` and used as the prefix of
            the JSON output file names.
    """
    mode = MODE
    word2index, index2word, word2vec, index2each, label_size_each, data_idx_each = load_data(
        frequency)
    (label_size, label_lexname_size, label_rootaffix_size,
     label_sememe_size) = label_size_each
    (data_train_idx, data_dev_idx, data_test_500_seen_idx,
     data_test_500_unseen_idx, data_defi_c_idx,
     data_desc_c_idx) = data_idx_each
    (index2sememe, index2lexname, index2rootaffix) = index2each
    # ndarray so that lists of indices can be mapped to words in one lookup.
    index2word = np.array(index2word)
    test_dataset = MyDataset(data_test_500_seen_idx +
                             data_test_500_unseen_idx + data_desc_c_idx)
    valid_dataset = MyDataset(data_dev_idx)
    train_dataset = MyDataset(data_train_idx + data_defi_c_idx)

    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   collate_fn=my_collate_fn)
    valid_dataloader = torch.utils.data.DataLoader(valid_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   collate_fn=my_collate_fn)
    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  collate_fn=my_collate_fn)

    print('DataLoader prepared. Batch_size [%d]' % batch_size)
    print('Train dataset: ', len(train_dataset))
    print('Valid dataset: ', len(valid_dataset))
    print('Test dataset: ', len(test_dataset))
    data_all_idx = data_train_idx + data_dev_idx + data_test_500_seen_idx + data_test_500_unseen_idx + data_defi_c_idx

    sememe_num = len(index2sememe)
    # label_size, not len(word2index): only target words' features are used.
    wd2sem = word2feature(data_all_idx, label_size, sememe_num, 'sememes')
    wd_sems = label_multihot(wd2sem, sememe_num)
    wd_sems = torch.from_numpy(np.array(wd_sems)).to(device)
    lexname_num = len(index2lexname)
    wd2lex = word2feature(data_all_idx, label_size, lexname_num, 'lexnames')
    wd_lex = label_multihot(wd2lex, lexname_num)
    wd_lex = torch.from_numpy(np.array(wd_lex)).to(device)
    rootaffix_num = len(index2rootaffix)
    wd2ra = word2feature(data_all_idx, label_size, rootaffix_num, 'root_affix')
    wd_ra = label_multihot(wd2ra, rootaffix_num)
    wd_ra = torch.from_numpy(np.array(wd_ra)).to(device)
    mask_s = mask_noFeature(label_size, wd2sem, sememe_num)
    mask_l = mask_noFeature(label_size, wd2lex, lexname_num)
    mask_r = mask_noFeature(label_size, wd2ra, rootaffix_num)

    # Keyword arguments shared by every model call below.
    feature_kwargs = dict(ws=wd_sems,
                          wl=wd_lex,
                          wr=wd_ra,
                          msk_s=mask_s,
                          msk_l=mask_l,
                          msk_r=mask_r,
                          mode=MODE)

    model = Encoder(vocab_size=len(word2index),
                    embed_dim=word2vec.shape[1],
                    hidden_dim=300,
                    layers=1,
                    class_num=label_size,
                    sememe_num=sememe_num,
                    lexname_num=lexname_num,
                    rootaffix_num=rootaffix_num)
    # Initialise the embedding layer from the pre-trained vectors.
    model.embedding.weight.data = torch.from_numpy(word2vec)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Adam
    best_valid_accu = 0
    for epoch in range(epoch_num):
        print('epoch: ', epoch)
        model.train()
        train_loss = 0
        label_list = list()
        pred_list = list()
        for words_t, definition_words_t in tqdm(train_dataloader,
                                                disable=verbose):
            optimizer.zero_grad()
            loss, _, indices = model('train',
                                     x=definition_words_t,
                                     w=words_t,
                                     **feature_kwargs)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
            # Keep only the first 100 ranked candidates for accuracy@1/10/100.
            predicted = indices[:, :100].detach().cpu().numpy().tolist()
            train_loss += loss.item()
            label_list.extend(words_t.detach().cpu().numpy())
            pred_list.extend(predicted)
        train_accu_1, train_accu_10, train_accu_100 = evaluate(
            label_list, pred_list)
        del label_list
        del pred_list
        gc.collect()
        print('train_loss: ', train_loss / len(train_dataset))
        print('train_accu(1/10/100): %.2f %.2F %.2f' %
              (train_accu_1, train_accu_10, train_accu_100))
        model.eval()
        with torch.no_grad():
            valid_loss = 0
            label_list = []
            pred_list = []
            for words_t, definition_words_t in tqdm(valid_dataloader,
                                                    disable=verbose):
                # The 'train' call path is used here because it also returns
                # the loss; gradients are disabled by the no_grad context.
                loss, _, indices = model('train',
                                         x=definition_words_t,
                                         w=words_t,
                                         **feature_kwargs)
                predicted = indices[:, :100].detach().cpu().numpy().tolist()
                valid_loss += loss.item()
                label_list.extend(words_t.detach().cpu().numpy())
                pred_list.extend(predicted)
            valid_accu_1, valid_accu_10, valid_accu_100 = evaluate(
                label_list, pred_list)
            print('valid_loss: ', valid_loss / len(valid_dataset))
            print('valid_accu(1/10/100): %.2f %.2F %.2f' %
                  (valid_accu_1, valid_accu_10, valid_accu_100))
            del label_list
            del pred_list
            gc.collect()

            # Model selection is driven by validation top-10 accuracy.
            if valid_accu_10 > best_valid_accu:
                best_valid_accu = valid_accu_10
                print('-----best_valid_accu-----')
                label_list = []
                pred_list = []
                for words_t, definition_words_t in tqdm(test_dataloader,
                                                        disable=verbose):
                    indices = model('test',
                                    x=definition_words_t,
                                    w=words_t,
                                    **feature_kwargs)
                    predicted = indices[:, :1000].detach().cpu().numpy(
                    ).tolist()
                    label_list.extend(words_t.detach().cpu().numpy())
                    pred_list.extend(predicted)
                test_accu_1, test_accu_10, test_accu_100, median, variance = evaluate_test(
                    label_list, pred_list)
                print('test_accu(1/10/100): %.2f %.2F %.2f %.2f %.2f' %
                      (test_accu_1, test_accu_10, test_accu_100, median,
                       variance))
                if epoch > 5:
                    # Context managers make sure the JSON files are flushed
                    # and closed (the original leaked both file handles).
                    with open(mode + '_label_list.json', 'w') as label_file:
                        json.dump((index2word[label_list]).tolist(),
                                  label_file)
                    with open(mode + '_pred_list.json', 'w') as pred_file:
                        json.dump((index2word[np.array(pred_list)]).tolist(),
                                  pred_file)
                del label_list
                del pred_list
                gc.collect()
예제 #13
0
class SolverNMsgMultipleDecodersDeepSteg(Solver):
    """Solver that hides several messages in one carrier, one decoder each.

    A single carrier encoder (``enc_c``) and carrier decoder (``dec_c``)
    embed the concatenated messages into the carrier; ``dec_m`` holds one
    ``MsgDecoder`` per message, each recovering its own message from the
    reconstructed carrier. All sub-networks share a single optimizer.
    """

    def __init__(self, config):
        """Build the models, wrap them in DataParallel and set up training.

        Args:
            config: Solver configuration, forwarded to the base ``Solver``.
        """
        super(SolverNMsgMultipleDecodersDeepSteg, self).__init__(config)
        logger.info("running multiple decoders solver!")

        # ------ create models ------
        # Carrier decoder input channels: all messages + carrier + 64 more
        # (presumably the encoder's feature channels - TODO confirm).
        self.dec_c_conv_dim = self.n_messages + 1 + 64
        self.build_models()

        # ------ make parallel ------
        self.enc_c = nn.DataParallel(self.enc_c)
        self.dec_c = nn.DataParallel(self.dec_c)
        self.dec_m = [nn.DataParallel(m) for m in self.dec_m]

        # ------ create optimizer ------
        # dec_m is a plain list, so its parameters are gathered by hand.
        dec_m_params = []
        for decoder in self.dec_m:
            dec_m_params += list(decoder.parameters())
        params = list(self.enc_c.parameters()) \
               + list(self.dec_c.parameters()) \
               + dec_m_params
        self.opt = self.opt_type(params, lr=self.lr)
        self.lr_sched = StepLR(self.opt, step_size=20, gamma=0.5)

        # ------ send to cuda ------
        self.enc_c.to(self.device)
        self.dec_c.to(self.device)
        self.dec_m = [m.to(self.device) for m in self.dec_m]

        if self.load_ckpt_dir:
            self.load_models(self.load_ckpt_dir)

        logger.debug(self.enc_c)
        logger.debug(self.dec_c)
        logger.debug(self.dec_m)

    def build_models(self):
        """Instantiate the carrier encoder/decoder and the message decoders."""
        super(SolverNMsgMultipleDecodersDeepSteg, self).build_models()

        self.enc_c = Encoder(block_type=self.block_type,
                             n_layers=self.config.enc_n_layers)

        self.dec_c = CarrierDecoder(conv_dim=self.dec_c_conv_dim,
                                    block_type=self.block_type,
                                    n_layers=self.config.dec_c_n_layers)

        self.dec_m = [MsgDecoder(conv_dim=self.dec_m_conv_dim,
                                 block_type=self.block_type) for _ in range(self.n_messages)]

    def save_models(self, suffix=''):
        """Save every sub-network's state dict under ``ckpt_dir/suffix``.

        Args:
            suffix: Sub-directory name appended to ``self.ckpt_dir``.
        """
        logger.info(f"saving model to: {self.ckpt_dir}\n==> suffix: {suffix}")
        makedirs(join(self.ckpt_dir, suffix), exist_ok=True)
        torch.save(self.enc_c.state_dict(), join(self.ckpt_dir, suffix, "enc_c.ckpt"))
        torch.save(self.dec_c.state_dict(), join(self.ckpt_dir, suffix, "dec_c.ckpt"))
        for i, m in enumerate(self.dec_m):
            torch.save(m.state_dict(), join(self.ckpt_dir, suffix, f"dec_m_{i}.ckpt"))

    def load_models(self, ckpt_dir):
        """Load every sub-network's state dict from ``ckpt_dir``.

        Args:
            ckpt_dir: Directory containing ``enc_c.ckpt``, ``dec_c.ckpt`` and
                one ``dec_m_{i}.ckpt`` per message decoder.
        """
        # map_location lets checkpoints saved on a GPU be loaded on whatever
        # device this solver runs on (including CPU-only machines).
        self.enc_c.load_state_dict(
            torch.load(join(ckpt_dir, "enc_c.ckpt"), map_location=self.device))
        self.dec_c.load_state_dict(
            torch.load(join(ckpt_dir, "dec_c.ckpt"), map_location=self.device))
        for i, m in enumerate(self.dec_m):
            m.load_state_dict(
                torch.load(join(ckpt_dir, f"dec_m_{i}.ckpt"), map_location=self.device))
        logger.info("loaded models")

    def reset_grad(self):
        """Zero the gradients of all optimized parameters."""
        self.opt.zero_grad()

    def train_mode(self):
        """Put every sub-network into training mode."""
        super(SolverNMsgMultipleDecodersDeepSteg, self).train_mode()
        self.enc_c.train()
        self.dec_c.train()
        for model in self.dec_m:
            model.train()

    def eval_mode(self):
        """Put every sub-network into evaluation mode."""
        super(SolverNMsgMultipleDecodersDeepSteg, self).eval_mode()
        self.enc_c.eval()
        self.dec_c.eval()
        for model in self.dec_m:
            model.eval()

    def step(self):
        """Apply one optimizer update; step the LR schedule once per epoch."""
        self.opt.step()
        # cur_iter hitting a multiple of the loader length marks an epoch end.
        if self.cur_iter % len(self.train_loader) == 0:
            self.lr_sched.step()

    def incur_loss(self, carrier, carrier_reconst, msg, msg_reconst):
        """Compute the weighted carrier + message reconstruction loss.

        Args:
            carrier: Original carrier tensor.
            carrier_reconst: Carrier reconstructed by ``dec_c``.
            msg: List of original message tensors.
            msg_reconst: List of reconstructed messages, aligned with ``msg``.

        Returns:
            Tuple ``(loss, losses_log)`` where ``loss`` is the weighted sum
            used for backprop and ``losses_log`` maps ``'carrier_loss'`` and
            ``'avg_msg_loss'`` to Python floats for logging.
        """
        n_messages = len(msg)
        losses_log = defaultdict(int)
        carrier, msg = carrier.to(self.device), [msg_i.to(self.device) for msg_i in msg]
        all_msg_loss = 0
        carrier_loss = self.reconstruction_loss(carrier_reconst, carrier, type=self.loss_type)
        for i in range(n_messages):
            msg_loss = self.reconstruction_loss(msg_reconst[i], msg[i], type=self.loss_type)
            all_msg_loss += msg_loss
        losses_log['carrier_loss'] = carrier_loss.item()
        losses_log['avg_msg_loss'] = all_msg_loss.item() / self.n_messages
        loss = self.lambda_carrier_loss * carrier_loss + self.lambda_msg_loss * all_msg_loss

        return loss, losses_log

    def forward(self, carrier, carrier_phase, msg):
        """Embed the messages into the carrier and decode them back.

        Args:
            carrier: Carrier tensor.
            carrier_phase: Carrier phase, used only when the STFT round-trip
                noise is applied.
            msg: List of message tensors, one per message decoder.

        Returns:
            Tuple ``(carrier_reconst, msg_reconst_list)``: the embedded
            carrier and the list of messages decoded from it.
        """
        assert isinstance(carrier, torch.Tensor) and isinstance(msg, list)
        carrier, carrier_phase, msg = carrier.to(self.device), \
                                      carrier_phase.to(self.device), \
                                      [msg_i.to(self.device) for msg_i in msg]
        msg_reconst_list = []

        # create embedded carrier
        carrier_enc = self.enc_c(carrier)  # encode the carrier
        msg_enc = torch.cat(msg, dim=1)  # concat all msg_i into single tensor
        merged_enc = torch.cat((carrier_enc, carrier, msg_enc), dim=1)  # concat encodings on features axis
        carrier_reconst = self.dec_c(merged_enc)  # decode carrier

        # Past the configured iteration, stop message-decoder gradients from
        # flowing back into the carrier networks (-1 disables this).
        if self.carrier_detach != -1 and self.cur_iter > self.carrier_detach:
            carrier_reconst = carrier_reconst.detach()

        # add stft noise to carrier: inverse transform followed by a forward
        # transform, always applied in test mode
        if (self.add_stft_noise != -1 and self.cur_iter > self.add_stft_noise) or self.mode == 'test':
            self.stft.to(self.device)
            y = self.stft.inverse(carrier_reconst.squeeze(1), carrier_phase.squeeze(1))
            carrier_reconst_tag, _ = self.stft.transform(y.squeeze(1))
            carrier_reconst_tag = carrier_reconst_tag.unsqueeze(1)
            self.stft.to('cpu')
        else:
            carrier_reconst_tag = carrier_reconst

        # add different types of noise to carrier
        # (gaussian, speckle, salt and pepper)
        if self.add_carrier_noise:
            carrier_reconst_tag = add_noise(carrier_reconst_tag,
                                            self.add_carrier_noise,
                                            self.carrier_noise_norm)

        # decode messages from carrier
        for i in range(len(msg)):  # decode each msg_i using decoder_m_i
            msg_reconst = self.dec_m[i](carrier_reconst_tag)
            msg_reconst_list.append(msg_reconst)

        return carrier_reconst, msg_reconst_list
예제 #14
0
class KeplerModel(pl.LightningModule):
    """Lightning wrapper that trains ``Encoder`` as a sequence classifier.

    The wrapped model is built from the module-level ``config`` mapping and
    optimised with cross-entropy loss; balanced accuracy and weighted F1 are
    reported for both training and validation.
    """

    def __init__(self):
        super(KeplerModel, self).__init__()

        #Initialize Model Parameters Using Config Properties
        self.model = Encoder(config['seq_length'], config['hidden_size'],
                             config['output_dim'], config['n_layers'])

        #Initialize a Cross Entropy Loss Criterion for Training
        self.criterion = torch.nn.CrossEntropyLoss()

    #Define a Forward Pass of the Model
    def forward(self, x, h):
        """Run the wrapped model on input ``x`` with hidden state ``h``."""
        return self.model.forward(x, h)

    def _loss_and_metrics(self, batch):
        """Run one batch and return ``(loss, balanced_accuracy, weighted_f1)``.

        Shared by the training and validation steps so both compute the loss
        and metrics the same way.
        """
        #Unpack Data and Labels from Batch
        x, y = batch

        #Insert a middle axis: (batch_size, seq_length) -> (batch_size, 1, seq_length)
        #(assumes x arrives as a 2-D tensor - TODO confirm)
        x = x.view(x.size(0), -1, x.size(1))

        #Initalize the hidden state for forward pass
        h = self.model.init_hidden(x.size(0))

        #Forward Pass Through Model
        out, h = self.forward(x, h)

        #Flatten labels to shape (batch_size,). reshape(-1) is used instead of
        #squeeze() because squeeze() turns a batch of one into a 0-dim tensor,
        #which CrossEntropyLoss rejects for a batched input.
        labels = y.long().reshape(-1)

        #Calculate Cross Entropy Loss
        loss = self.criterion(out, labels)

        #Obtain Class Labels
        y_hat = torch.max(out, 1)[1]

        #sklearn works on host arrays, so move off the device (and out of the
        #autograd graph) before computing metrics.
        y_true = labels.detach().cpu().numpy()
        y_pred = y_hat.detach().cpu().numpy()

        #Compute the balanced accuracy (weights based on number of ex. in each class)
        accuracy = balanced_accuracy_score(y_true, y_pred)

        #Compute weighted f1 score to account for class imbalance
        f1 = f1_score(y_true, y_pred, average='weighted')

        return loss, accuracy, f1

    def training_step(self, batch, batch_idx):
        """Compute the training loss and metrics for one batch.

        Returns:
            Dict with the ``'loss'`` tensor to optimise and a ``'log'`` dict
            of scalar metrics for tensorboard.
        """
        #Set Model to Training Mode
        self.model.train()

        #Zero out the model gradients to avoid accumulation
        self.model.zero_grad()

        loss, accuracy, f1 = self._loss_and_metrics(batch)

        #Create metric object for tensorboard logging
        tensorboard_logs = {
            'train_loss': loss.item(),
            'accuracy': accuracy,
            'f1': f1
        }

        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and metrics for one batch.

        Returns:
            Dict with ``'val_loss'``, ``'val_accuracy'`` and ``'val_f1'``
            tensors, later aggregated by ``validation_end``.
        """
        #Set Model to Eval Mode
        self.model.eval()

        loss, accuracy, f1 = self._loss_and_metrics(batch)

        #Wrap the scalar metrics in tensors so validation_end can stack them
        val_accuracy = torch.Tensor([accuracy])
        val_f1 = torch.Tensor([f1])

        #Create a metrics object
        metrics = {
            'val_loss': loss,
            'val_accuracy': val_accuracy,
            'val_f1': val_f1
        }

        return metrics

    def validation_end(self, outputs):
        """Average the per-batch validation metrics over the whole epoch."""
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['val_accuracy'] for x in outputs]).mean()
        avg_f1 = torch.stack([x['val_f1'] for x in outputs]).mean()

        tensorboard_logs = {
            'val_loss': avg_loss,
            'val_acc': avg_acc,
            'val_f1': avg_f1
        }

        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr=0.001)."""
        return torch.optim.Adam(self.parameters(), lr=0.001)

    @pl.data_loader
    def train_dataloader(self):
        """Return the shuffled training loader (batch size 64)."""
        # REQUIRED
        return DataLoader(KeplerDataset(mode="train"),
                          batch_size=64,
                          shuffle=True)

    @pl.data_loader
    def val_dataloader(self):
        """Return the validation loader over the "test" split (batch size 128)."""
        # REQUIRED
        return DataLoader(KeplerDataset(mode="test"),
                          batch_size=128,
                          shuffle=True)
예제 #15
0
파일: train.py 프로젝트: kondo-kk/flu_trend
def _rollout_decoder(dec, enc_vec, x, h, c, predict_len,
                     teacher_inputs=None, teacher_prob=0.0):
    """Run the decoder for ``predict_len`` steps and concatenate its outputs.

    Each step feeds the previous output back in as the next input. When
    ``teacher_inputs`` is given, the next input is replaced by
    ``teacher_inputs[:, step]`` with probability ``teacher_prob``.
    """
    outputs = []
    for pi in range(predict_len):
        x, h, c = dec(x, h, c, enc_vec)
        outputs.append(x)
        # Exactly one random draw per step, and only in teacher-forced mode.
        if teacher_inputs is not None and np.random.random() < teacher_prob:
            x = teacher_inputs[:, pi]
    return torch.cat(outputs, dim=1)


def train(region):
    """Train the encoder/decoder forecaster for ``region`` and evaluate it.

    Trains for a fixed number of epochs with teacher forcing, printing the
    train and test loss per epoch, saves both state dicts under ``models/``,
    then re-runs the test set one sample at a time to compute RMSE and the
    correlation of the first predicted step, after undoing the min-max
    scaling.

    Args:
        region: Region identifier, forwarded to ``create_dataset`` and used
            in the saved model file names and the printed report.

    Returns:
        A summary string of the form ``"{region} RMSE {rmse} r {r}"``.
    """
    np.random.seed(0)
    torch.manual_seed(0)

    input_len = 10
    encoder_units = 32
    decoder_units = 64
    encoder_rnn_layers = 3
    encoder_dropout = 0.2
    decoder_dropout = 0.2
    input_size = 2
    output_size = 1
    predict_len = 5
    batch_size = 16
    epochs = 500
    force_teacher = 0.8

    train_dataset, test_dataset, train_max, train_min = create_dataset(
        input_len, predict_len, region)
    # drop_last=True keeps every batch at batch_size rows, matching the
    # fixed-size state returned by dec.initHidden(batch_size) below.
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    enc = Encoder(input_size, encoder_units, input_len,
                  encoder_rnn_layers, encoder_dropout)
    dec = Decoder(encoder_units*2, decoder_units, input_len,
                  input_len, decoder_dropout, output_size)

    optimizer = AdaBound(list(enc.parameters()) +
                         list(dec.parameters()), 0.01, final_lr=0.1)
    criterion = nn.MSELoss()

    mb = master_bar(range(epochs))
    for ep in mb:
        train_loss = 0
        enc.train()
        dec.train()
        for encoder_input, decoder_input, target in progress_bar(train_loader, parent=mb):
            optimizer.zero_grad()
            enc_vec = enc(encoder_input)
            # The decoder's hidden state starts from the encoder's last step.
            h = enc_vec[:, -1, :]
            _, c = dec.initHidden(batch_size)
            pred = _rollout_decoder(dec, enc_vec, decoder_input[:, 0], h, c,
                                    predict_len,
                                    teacher_inputs=decoder_input,
                                    teacher_prob=force_teacher)
            loss = criterion(pred, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        test_loss = 0
        enc.eval()
        dec.eval()
        for encoder_input, decoder_input, target in progress_bar(test_loader, parent=mb):
            with torch.no_grad():
                enc_vec = enc(encoder_input)
                h = enc_vec[:, -1, :]
                _, c = dec.initHidden(batch_size)
                pred = _rollout_decoder(dec, enc_vec, decoder_input[:, 0],
                                        h, c, predict_len)
            loss = criterion(pred, target)
            test_loss += loss.item()
        print(
            f"Epoch {ep} Train Loss {train_loss/len(train_loader)} Test Loss {test_loss/len(test_loader)}")

    # exist_ok avoids the check-then-create race of os.path.exists + mkdir.
    os.makedirs("models", exist_ok=True)
    torch.save(enc.state_dict(), f"models/{region}_enc.pth")
    torch.save(dec.state_dict(), f"models/{region}_dec.pth")

    # Final evaluation walks the test set one sample at a time.
    test_loader = DataLoader(test_dataset, batch_size=1,
                             shuffle=False, drop_last=False)

    p = 0  # index of the forecast step that is scored
    predicted = []
    true_target = []
    enc.eval()
    dec.eval()
    for encoder_input, decoder_input, target in progress_bar(test_loader, parent=mb):
        with torch.no_grad():
            enc_vec = enc(encoder_input)
            # Seed the hidden state from the encoder exactly as in training;
            # the original used initHidden's h here, so the model was
            # evaluated with a different initial state than it was trained on.
            h = enc_vec[:, -1, :]
            _, c = dec.initHidden(1)
            pred = _rollout_decoder(dec, enc_vec, decoder_input[:, 0],
                                    h, c, predict_len)
            predicted += [pred[0, p].item()]
            true_target += [target[0, p].item()]
    # Undo the min-max scaling before computing the metrics.
    predicted = np.array(predicted).reshape(1, -1)
    predicted = predicted * (train_max - train_min) + train_min
    true_target = np.array(true_target).reshape(1, -1)
    true_target = true_target * (train_max - train_min) + train_min
    rmse, peasonr = calc_metric(predicted, true_target)
    print(f"{region} RMSE {rmse}")
    print(f"{region} r {peasonr[0]}")
    return f"{region} RMSE {rmse} r {peasonr[0]}"
예제 #16
0
def main(args):
    """Train the image-captioning encoder/decoder pair.

    Loads the image/report table from ``args.img_report_path``, randomly
    splits it 80/20 into training and validation frames, builds the data
    loaders and the two models, then trains for ``args.num_epochs`` epochs.
    Validation runs every ``args.validation_step`` iterations and a progress
    line is printed every ``args.train_log_step`` iterations.

    Args:
        args: parsed options. This function reads ``model_path``,
            ``crop_size``, ``vocab_path``, ``img_report_path``,
            ``image_dir``, ``batch_size``, ``num_workers``, ``embed_size``,
            ``hidden_size``, ``encoder_lr``, ``decoder_lr``, ``num_epochs``,
            ``validation_step`` and ``train_log_step``.
    """
    # Report which kind of hardware PyTorch can see.
    print('Using GPU' if torch.cuda.is_available() else 'Using CPU')

    # Make sure the directory used for saving models exists.
    model_dir = args.model_path
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # Image preprocessing pipeline (random crop + flip, then normalisation).
    preprocessing_steps = [
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    transform = transforms.Compose(preprocessing_steps)

    # Load the pickled vocabulary wrapper.
    with open(args.vocab_path, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    # Random 80/20 split: sample the training rows, and whatever is left in
    # the original frame after dropping them becomes the validation set.
    report_frame = pd.read_csv(args.img_report_path)
    n_rows = len(report_frame)
    print('Data Total Size:{}'.format(n_rows))
    train_data = report_frame.sample(n=int(n_rows * 0.8))
    report_frame.drop(list(train_data.index), inplace=True)
    val_data = report_frame
    for frame in (train_data, val_data):
        frame.reset_index(level=0, inplace=True)
    print('Training Data:{}'.format(len(train_data)))
    print('Valdiation Data:{}'.format(len(val_data)))

    def build_loader(frame, split_name):
        # Both loaders share every setting except the data frame and split.
        return get_loader(args.image_dir,
                          vocab,
                          frame,
                          transform,
                          args.batch_size,
                          shuffle=True,
                          num_workers=args.num_workers,
                          split=split_name)

    train_loader = build_loader(train_data, 'Train')
    val_loader = build_loader(val_data, 'Val')

    # Models
    encoder = Encoder().to(device)
    decoder = Decoder(embed_dim=args.embed_size,
                      decoder_dim=args.hidden_size,
                      vocab_size=len(vocab)).to(device)

    # Loss and one Adam optimizer per model, over trainable parameters only.
    criterion = nn.CrossEntropyLoss()
    encoder_optimizer = torch.optim.Adam(
        params=(p for p in encoder.parameters() if p.requires_grad),
        lr=args.encoder_lr)
    decoder_optimizer = torch.optim.Adam(
        params=(p for p in decoder.parameters() if p.requires_grad),
        lr=args.decoder_lr)

    encoder.train()
    decoder.train()

    steps_per_epoch = len(train_loader)
    for epoch in range(args.num_epochs):
        for step, (images, captions, lengths) in enumerate(train_loader):
            images = images.to(device)
            captions = captions.to(device)

            # Forward pass through encoder and decoder.
            features = encoder(images)
            scores, cap_sorted, decode_len = decoder(features, captions,
                                                     lengths)

            # Targets are the captions without the leading <start> token.
            targets = cap_sorted[:, 1:]

            # Packing drops the <pad> positions from scores and targets.
            packed_scores = pack_padded_sequence(scores,
                                                 decode_len,
                                                 batch_first=True)
            packed_targets = pack_padded_sequence(targets,
                                                  decode_len,
                                                  batch_first=True)
            scores = packed_scores.data
            targets = packed_targets.data

            # Backward pass and weight update for both models.
            loss = criterion(scores, targets)
            decoder_optimizer.zero_grad()
            encoder_optimizer.zero_grad()
            loss.backward()
            decoder_optimizer.step()
            encoder_optimizer.step()

            # Periodic validation.
            if step % args.validation_step == 0:
                validation(val_loader, encoder, decoder, criterion)

            # Periodic progress log.
            if step % args.train_log_step == 0:
                loss_value = loss.item()
                print(
                    '[Training] -  Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                    .format(epoch + 1, args.num_epochs, step + 1,
                            steps_per_epoch, loss_value, np.exp(loss_value)))
예제 #17
0
def main():
    """Build the data pipeline, models and optimiser, then start training.

    Options come from ``parse_args()``. When ``options.checkpoint`` is set
    the checkpoint is loaded and its encoder/decoder/optimiser state is used
    to resume; otherwise ``default_checkpoint`` is used. The assembled
    objects are handed to ``train``.
    """
    options = parse_args()
    torch.manual_seed(options.seed)

    # CUDA is used only when available and not disabled on the command line.
    cuda_enabled = use_cuda and not options.no_cuda
    hardware = "cuda" if cuda_enabled else "cpu"
    device = torch.device(hardware)

    if options.checkpoint:
        checkpoint = load_checkpoint(options.checkpoint, cuda=cuda_enabled)
    else:
        checkpoint = default_checkpoint
    print("Running {} epochs on {}".format(options.num_epochs, hardware))

    saved_models = checkpoint["model"]
    encoder_checkpoint = saved_models.get("encoder")
    decoder_checkpoint = saved_models.get("decoder")
    if encoder_checkpoint is not None:
        # Summarise the most recent metrics stored in the checkpoint.
        resume_template = ("Resuming from - Epoch {}: "
                           "Train Accuracy = {train_accuracy:.5f}, "
                           "Train Loss = {train_loss:.5f}, "
                           "Validation Accuracy = {validation_accuracy:.5f}, "
                           "Validation Loss = {validation_loss:.5f}, ")
        print(
            resume_template.format(
                checkpoint["epoch"],
                train_accuracy=checkpoint["train_accuracy"][-1],
                train_loss=checkpoint["train_losses"][-1],
                validation_accuracy=checkpoint["validation_accuracy"][-1],
                validation_loss=checkpoint["validation_losses"][-1],
            ))

    # The train and validation pipelines differ only in the ground-truth file.
    dataset_options = {
        "root": root,
        "crop": options.crop,
        "transform": transformers,
    }
    loader_options = {
        "batch_size": options.batch_size,
        "shuffle": True,
        "num_workers": options.num_workers,
        "collate_fn": collate_batch,
    }
    train_dataset = CrohmeDataset(gt_train, tokensfile, **dataset_options)
    train_data_loader = DataLoader(train_dataset, **loader_options)
    validation_dataset = CrohmeDataset(gt_validation, tokensfile,
                                       **dataset_options)
    validation_data_loader = DataLoader(validation_dataset, **loader_options)

    criterion = nn.CrossEntropyLoss().to(device)
    enc = Encoder(img_channels=3,
                  dropout_rate=options.dropout_rate,
                  checkpoint=encoder_checkpoint).to(device)
    dec = Decoder(
        len(train_dataset.id_to_token),
        low_res_shape,
        high_res_shape,
        checkpoint=decoder_checkpoint,
        device=device,
    ).to(device)
    enc.train()
    dec.train()

    # One optimiser over every trainable parameter: encoder first, then
    # decoder.
    params_to_optimise = [
        param for model in (enc, dec) for param in model.parameters()
        if param.requires_grad
    ]
    optimiser = optim.Adadelta(params_to_optimise,
                               lr=options.lr,
                               weight_decay=options.weight_decay)
    optimiser_state = checkpoint.get("optimiser")
    if optimiser_state:
        optimiser.load_state_dict(optimiser_state)
    # The scheduler resets the LR to "initial_lr" after a state load, which
    # would otherwise always restore the very first learning rate rather than
    # the requested one, so the initial LR is set explicitly here.
    for group in optimiser.param_groups:
        group["initial_lr"] = options.lr
    # Multiply the learning rate by lr_factor every lr_epochs epochs.
    lr_scheduler = optim.lr_scheduler.StepLR(optimiser,
                                             step_size=options.lr_epochs,
                                             gamma=options.lr_factor)

    train(
        enc,
        dec,
        optimiser,
        criterion,
        train_data_loader,
        validation_data_loader,
        teacher_forcing_ratio=options.teacher_forcing,
        lr_scheduler=lr_scheduler,
        print_epochs=options.print_epochs,
        device=device,
        num_epochs=options.num_epochs,
        checkpoint=checkpoint,
        prefix=options.prefix,
        max_grad_norm=options.max_grad_norm,
    )
예제 #18
0
# Script fragment: builds the encoder/decoder/discriminator, their losses and
# optimizers, and starts the training loop. `device`, `BasicBlock`,
# `learning_rate`, `lmbda`, `epochs` and `train_loader` are expected to be
# defined earlier in the script (not visible in this excerpt).
encoder = Encoder(BasicBlock).to(device)
decoder = Decoder(BasicBlock).to(device)
discriminator = Discriminator(BasicBlock).to(device)

# Loss functions made available to the training loop.
ce_loss = nn.CrossEntropyLoss()
l1_loss = nn.L1Loss()
mse_loss = nn.MSELoss()

# One Adam optimizer per network. The first three use a fixed learning rate
# of 0.0002 with no weight decay; the discriminator uses `learning_rate` and
# `lmbda` as weight decay.
# NOTE(review): `pre_encoder` is used here but is not created in this
# excerpt — presumably it is built earlier in the script; confirm.
pre_encoder_optim = optim.Adam(pre_encoder.parameters(), lr = 0.0002, weight_decay=0)
encoder_optim = optim.Adam(encoder.parameters(), lr = 0.0002, weight_decay=0)
decoder_optim = optim.Adam(decoder.parameters(), lr = 0.0002, weight_decay=0)
discriminator_optim = optim.Adam(discriminator.parameters(), lr = learning_rate, weight_decay=lmbda)


# Put every network in training mode.
pre_encoder = pre_encoder.train()
encoder = encoder.train()
decoder = decoder.train()
discriminator = discriminator.train()


for epoch in range(epochs):
    for i, batch in enumerate(train_loader):
        
        # Each batch is a pair (image, real_image); both are moved to `device`.
        image, real_image = batch
        image = image.to(device)
        real_image = real_image.to(device)
        # Per-sample integer labels: 1 for "real", 0 for "fake", one entry
        # per element along the first dimension of `image`.
        real_label = torch.ones([image.shape[0]]).long().to(device)
        fake_label = torch.zeros([image.shape[0]]).long().to(device)
        
        # Clear gradients accumulated from the previous iteration.
        pre_encoder_optim.zero_grad()
        encoder_optim.zero_grad()
def train_dynamics(env, args, writer=None):
    """
    Trains the Dynamics module (encoder, decoder and D_Module). Supervised.

    Every ``args.log_interval`` iterations the current losses are logged and
    appended to the results dictionary. After each epoch a validation pass is
    run. Every ``args.checkpoint`` epochs, and once more after the last epoch,
    the results dictionary (including the model state dicts) is saved under
    ``args.out``.

    Arguments:
    env: the initialized environment (rllab/gym)
    args: input arguments
    writer: initialized summary writer for tensorboard

    Returns:
    None. Results are written to disk with ``torch.save``.
    """
    args.action_space = env.action_space

    # Initialize models
    enc = Encoder(env.observation_space.shape[0],
                  args.dim,
                  use_conv=args.use_conv)
    dec = Decoder(env.observation_space.shape[0],
                  args.dim,
                  use_conv=args.use_conv)
    d_module = D_Module(env.action_space.shape[0], args.dim, args.discrete)

    if args.from_checkpoint is not None:
        results_dict = torch.load(args.from_checkpoint)
        enc.load_state_dict(results_dict['enc'])
        dec.load_state_dict(results_dict['dec'])
        d_module.load_state_dict(results_dict['d_module'])

    # Materialise the parameters as a list. The optimizer consumes the
    # iterable once, and the same collection is needed again on every
    # iteration for gradient clipping. A bare chain()/parameters() iterator
    # would already be exhausted by then, silently turning clipping into a
    # no-op.
    all_params = list(
        chain(enc.parameters(), dec.parameters(), d_module.parameters()))

    if args.transfer:
        # Transfer mode: freeze encoder and decoder, train only d_module.
        for p in enc.parameters():
            p.requires_grad = False

        for p in dec.parameters():
            p.requires_grad = False
        all_params = list(d_module.parameters())

    optimizer = torch.optim.Adam(all_params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    if args.gpu:
        enc = enc.cuda()
        dec = dec.cuda()
        d_module = d_module.cuda()

    # Initialize datasets
    train_dataset = DynamicsDataset(args.train_set,
                                    args.train_size,
                                    batch=args.train_batch,
                                    rollout=args.rollout)
    val_dataset = DynamicsDataset(args.test_set,
                                  5000,
                                  batch=args.test_batch,
                                  rollout=args.rollout)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers)

    results_dict = {
        'dec_losses': [],
        'forward_losses': [],
        'inverse_losses': [],
        'total_losses': [],
        'enc': None,
        'dec': None,
        'd_module': None,
        'd_init': None,
        'args': args
    }

    total_action_taken = 0
    correct_predicted_a_hat = 0

    # create the mask here for re-weighting
    dec_mask = None
    if args.dec_mask is not None:
        dec_mask = torch.ones(9)
        # Map each feature name to its index in sorted order.
        game_vocab = {
            feature: idx
            for idx, feature in enumerate(
                sorted(env.game.all_possible_features()))
        }
        dec_mask[game_vocab['Agent']] = args.dec_mask
        dec_mask[game_vocab['Goal']] = args.dec_mask
        dec_mask = dec_mask.expand(args.batch_size, args.maze_length,
                                   args.maze_length, 9).contiguous().view(-1)
        dec_mask = Variable(dec_mask, requires_grad=False)
        if args.gpu:
            dec_mask = dec_mask.cuda()

    # NOTE(review): `d_init` is used below whenever args.framework is
    # "mazebase" but is never created in this function — presumably it is a
    # module-level object; confirm, otherwise that path raises NameError.
    for epoch in range(1, args.num_epochs + 1):
        enc.train()
        dec.train()
        d_module.train()

        if args.framework == "mazebase":
            d_init.train()

        # for measuring the accuracy
        current_epoch_actions = 0
        current_epoch_predicted_a_hat = 0

        start = time.time()
        for i, (states, target_actions) in enumerate(train_loader):

            optimizer.zero_grad()

            if args.framework != "mazebase":
                forward_loss, inv_loss, dec_loss, recon_loss, model_loss, _, _ = forward_planning(
                    i, states, target_actions, enc, dec, d_module, args)
            else:
                forward_loss, inv_loss, dec_loss, recon_loss, model_loss, current_epoch_predicted_a_hat, current_epoch_actions = multiple_forward(
                    i, states, target_actions, enc, dec, d_module, args,
                    d_init, dec_mask)

            loss = forward_loss + args.inv_loss_coef * inv_loss + \
                        args.dec_loss_coef * dec_loss

            if i % args.log_interval == 0:
                # Pull the scalar values out once; they are reused for the
                # text log, the results dictionary and tensorboard.
                dec_value = dec_loss.data[0]
                forward_value = forward_loss.data[0]
                inv_value = inv_loss.data[0]
                total_value = loss.data[0]

                log(
                    'Epoch [{}/{}]\tIter [{}/{}]\t'.format(
                        epoch, args.num_epochs, i+1, len(
                        train_dataset)//args.batch_size) + \
                    'Time: {:.2f}\t'.format(time.time() - start) + \
                    'Decoder Loss: {:.2f}\t'.format(dec_value) + \
                    'Forward Loss: {:.2f}\t'.format(forward_value) + \
                    'Inverse Loss: {:.2f}\t'.format(inv_value) + \
                    'Loss: {:.2f}\t'.format(total_value))

                results_dict['dec_losses'].append(dec_value)
                results_dict['forward_losses'].append(forward_value)
                results_dict['inverse_losses'].append(inv_value)
                results_dict['total_losses'].append(total_value)

                # write the summaries here
                if writer:
                    recon_value = recon_loss.data[0]
                    model_value = model_loss.data[0]
                    writer.add_scalar('dynamics/total_loss', total_value,
                                      epoch)
                    writer.add_scalar('dynamics/decoder', dec_value, epoch)
                    writer.add_scalar('dynamics/reconstruction_loss',
                                      recon_value, epoch)
                    writer.add_scalar('dynamics/next_state_prediction_loss',
                                      model_value, epoch)
                    writer.add_scalar('dynamics/inv_loss', inv_value, epoch)
                    writer.add_scalar('dynamics/forward_loss', forward_value,
                                      epoch)

                    writer.add_scalars(
                        'dynamics/all_losses', {
                            "total_loss": total_value,
                            "reconstruction_loss": recon_value,
                            "next_state_prediction_loss": model_value,
                            "decoder_loss": dec_value,
                            "inv_loss": inv_value,
                            "forward_loss": forward_value,
                        }, epoch)

            loss.backward()

            correct_predicted_a_hat += current_epoch_predicted_a_hat
            total_action_taken += current_epoch_actions

            # does it not work at all without grad clipping ?
            # `all_params` is a list, so this really clips the gradients of
            # the optimized parameters on every iteration.
            torch.nn.utils.clip_grad_norm(all_params, args.max_grad_norm)
            optimizer.step()

            # maybe add the generated image to add the logs
            # writer.add_image()

        # Run validation
        if val_loader is not None:
            enc.eval()
            dec.eval()
            d_module.eval()
            val_forward_loss, val_inv_loss, val_dec_loss = 0, 0, 0
            n_val_batches = 0
            for i, (states, target_actions) in enumerate(val_loader):
                f_loss, i_loss, d_loss, _, _, _, _ = forward_planning(
                    i, states, target_actions, enc, dec, d_module, args)
                val_forward_loss += f_loss
                val_inv_loss += i_loss
                val_dec_loss += d_loss
                n_val_batches += 1

            # Average over the number of batches actually seen. Dividing by
            # the last enumerate index would be off by one, and would divide
            # by zero for a single batch. With no batches there is nothing
            # to report.
            if n_val_batches > 0:
                val_loss = val_forward_loss + \
                        args.inv_loss_coef * val_inv_loss + \
                        args.dec_loss_coef * val_dec_loss
                mean_forward = val_forward_loss.data[0] / n_val_batches
                mean_inverse = val_inv_loss.data[0] / n_val_batches
                mean_decoder = val_dec_loss.data[0] / n_val_batches
                mean_total = val_loss.data[0] / n_val_batches
                if writer:
                    writer.add_scalar('val/forward_loss', mean_forward, epoch)
                    writer.add_scalar('val/inverse_loss', mean_inverse, epoch)
                    writer.add_scalar('val/decoder_loss', mean_decoder, epoch)
                log(
                    '[Validation]\t' + \
                    'Decoder Loss: {:.2f}\t'.format(mean_decoder) + \
                    'Forward Loss: {:.2f}\t'.format(mean_forward) + \
                    'Inverse Loss: {:.2f}\t'.format(mean_inverse) + \
                    'Loss: {:.2f}\t'.format(mean_total))
        if epoch % args.checkpoint == 0:
            results_dict['enc'] = enc.state_dict()
            results_dict['dec'] = dec.state_dict()
            results_dict['d_module'] = d_module.state_dict()
            if args.framework == "mazebase":
                results_dict['d_init'] = d_init.state_dict()
            torch.save(
                results_dict,
                os.path.join(args.out, 'dynamics_module_epoch%s.pt' % epoch))
            log('Saved model %s' % epoch)

    # Final save after the last epoch, refreshing every state dict
    # (including d_init for mazebase) so the file reflects the final weights.
    results_dict['enc'] = enc.state_dict()
    results_dict['dec'] = dec.state_dict()
    results_dict['d_module'] = d_module.state_dict()
    if args.framework == "mazebase":
        results_dict['d_init'] = d_init.state_dict()
    final_path = os.path.join(args.out, 'dynamics_module_epoch%s.pt' % epoch)
    torch.save(results_dict, final_path)
    print(final_path)
예제 #20
0
# Script fragment: start of an adversarial training loop over netE / netG /
# netD / netD2. `opt`, `start`, `num_epochs`, `trainloader`, `tfd` and the
# networks are expected to be defined earlier in the script (not visible).

# gowthami - what is this best score? and what is this train loss small?
# best_score = 0.5
## small == 3000 ## smallest value is 3000
# train_loss_small = 3000

# gowthami - did they use it in the final run? check with Yexin
print(use_decay_learning, use_linearly_decay)
print('start from %s ' % start)
# Running count of training iterations across all epochs.
n_iter = 0

starttime = time.time()

# Training only runs when no model path to load was supplied.
if (opt.model_load_path == ''):
    for epoch in range(start, num_epochs):
        # Put all four networks in training mode at the start of each epoch.
        netE.train()
        netD2.train()
        netG.train()
        netD.train()

        for i, data in enumerate(trainloader, 0):
            n_iter = n_iter + 1

            #         targets = data[1].to(device)
            #         print(torch.unique(targets, return_counts=True) )

            # optimize discriminator tfd times
            for t in range(tfd):
                # Clear gradients of the discriminator, encoder and generator.
                netD.zero_grad()
                netE.zero_grad()
                netG.zero_grad()
        # NOTE(review): from here on the code uses encoder/decoder/args
        # rather than netE/netG/opt, and its indentation does not continue
        # the loop body above. It looks like a separate excerpt spliced in
        # without a separator; confirm where it belongs.
        # Wrap the models (and their SWA copies when swa_params is set) in
        # nn.DataParallel.
        encoder = nn.DataParallel(encoder)
        decoder = nn.DataParallel(decoder)
        if swa_params:
            encoder_swa = nn.DataParallel(encoder_swa)
            decoder_swa = nn.DataParallel(decoder_swa)

    best_f1 = 0.0
    for epoch in range(resume, args.epochs):
        # Training is skipped entirely when args.test_model is set.
        training = True
        if args.test_model:
            training = False

        if training:
            # train
            # The encoder is switched to train mode only when it is being
            # fine-tuned; the decoder always is. SWA copies follow the same
            # rule when swa_params is set.
            if finetune_encoder:
                encoder.train()
                if swa_params:
                    encoder_swa.train()
            decoder.train()
            if swa_params:
                decoder_swa.train()
            for i, batch in enumerate(dataloader):
                iterations += 1
                # A batch is indexed positionally: images, labels, label
                # lengths and classification labels.
                images = batch[0]
                labels = batch[1]
                label_lengths = batch[2]
                labels_classification = batch[3].to('cuda')
                # Optional visualisation happens before images and labels
                # are moved to the GPU.
                if args.visualize_batch:
                    visualize_batch_fn(images, labels, label_lengths)
                images = images.to('cuda')
                labels = labels.to('cuda')
예제 #22
0
class SolverNMsgCond(Solver):
    """Solver that hides several messages in one carrier, conditioned on index.

    Each message is tagged with a one-hot encoding of its index before being
    merged with the carrier, and the same one-hot tag is appended to the
    reconstructed carrier when a particular message is decoded.

    Attributes read here but not set in this class (``n_messages``,
    ``opt_type``, ``lr``, ``device``, ``block_type``, ``config``,
    ``dec_m_conv_dim``, ``ckpt_dir``, ``load_ckpt_dir``, ``train_loader``,
    ``cur_iter``, ``mode``, ``stft``, the loss weights, ...) are expected to
    be provided by the ``Solver`` base class, which is not visible here.
    """

    def __init__(self, config):
        """Build the four networks, the optimizer and the LR schedule.

        Args:
            config: configuration object, forwarded to ``Solver.__init__``.
        """
        super(SolverNMsgCond, self).__init__(config)
        print("==> running conditional solver!")

        # ------ create models ------
        # Input channel count of the carrier decoder. It has to match the
        # tensor concatenated in forward().
        # NOTE(review): presumably 1 carrier channel + 64 carrier-encoder
        # channels + n_messages conditioned messages of (1 + n_messages)
        # channels each — confirm against Encoder.
        self.dec_c_conv_dim = self.n_messages * (self.n_messages + 1) + 1 + 64
        # self.dec_c_conv_dim = (self.n_messages+1) * 64
        self.build_models()

        # ------ make parallel ------
        self.enc_c = nn.DataParallel(self.enc_c)
        self.enc_m = nn.DataParallel(self.enc_m)
        self.dec_m = nn.DataParallel(self.dec_m)
        self.dec_c = nn.DataParallel(self.dec_c)

        # ------ create optimizers ------
        # A single optimizer updates all four networks jointly.
        params = list(self.enc_m.parameters()) \
               + list(self.enc_c.parameters()) \
               + list(self.dec_c.parameters()) \
               + list(self.dec_m.parameters())
        self.opt = self.opt_type(params, lr=self.lr)
        # Halve the learning rate every 20 scheduler steps.
        self.lr_sched = StepLR(self.opt, step_size=20, gamma=0.5)

        # ------ send to cuda ------
        self.enc_c.to(self.device)
        self.enc_m.to(self.device)
        self.dec_m.to(self.device)
        self.dec_c.to(self.device)

        if self.load_ckpt_dir:
            self.load_models(self.load_ckpt_dir)

        logger.debug(self.enc_m)
        logger.debug(self.dec_c)
        logger.debug(self.dec_m)

    def build_models(self):
        """Instantiate the message/carrier encoders and decoders."""
        super(SolverNMsgCond, self).build_models()

        # The message encoder's input is widened by n_messages channels for
        # the one-hot conditioning planes.
        self.enc_m = Encoder(conv_dim=1 + self.n_messages,
                             block_type=self.block_type,
                             n_layers=self.config.enc_n_layers)

        self.enc_c = Encoder(conv_dim=1,
                             block_type=self.block_type,
                             n_layers=self.config.enc_n_layers)

        self.dec_c = CarrierDecoder(conv_dim=self.dec_c_conv_dim,
                                    block_type=self.block_type,
                                    n_layers=self.config.dec_c_n_layers)

        # The message decoder also receives the n_messages conditioning planes.
        self.dec_m = MsgDecoder(conv_dim=self.dec_m_conv_dim + self.n_messages,
                                block_type=self.block_type)

    def save_models(self, suffix=''):
        """Save the state dict of every network under ``ckpt_dir/suffix``.

        Args:
            suffix: sub-directory name inside ``self.ckpt_dir``; created if
                missing.
        """
        logger.info(f"saving model to: {self.ckpt_dir}\n==> suffix: {suffix}")
        makedirs(join(self.ckpt_dir, suffix), exist_ok=True)
        torch.save(self.enc_c.state_dict(),
                   join(self.ckpt_dir, suffix, "enc_c.ckpt"))
        torch.save(self.enc_m.state_dict(),
                   join(self.ckpt_dir, suffix, "enc_m.ckpt"))
        torch.save(self.dec_c.state_dict(),
                   join(self.ckpt_dir, suffix, "dec_c.ckpt"))
        torch.save(self.dec_m.state_dict(),
                   join(self.ckpt_dir, suffix, "dec_m.ckpt"))

    def load_models(self, ckpt_dir):
        """Load the four network state dicts saved by ``save_models``.

        Args:
            ckpt_dir: directory containing ``enc_c.ckpt``, ``enc_m.ckpt``,
                ``dec_c.ckpt`` and ``dec_m.ckpt``.
        """
        self.enc_c.load_state_dict(torch.load(join(ckpt_dir, "enc_c.ckpt")))
        self.enc_m.load_state_dict(torch.load(join(ckpt_dir, "enc_m.ckpt")))
        self.dec_c.load_state_dict(torch.load(join(ckpt_dir, "dec_c.ckpt")))
        self.dec_m.load_state_dict(torch.load(join(ckpt_dir, "dec_m.ckpt")))
        logger.info("loaded models")

    def reset_grad(self):
        """Zero the gradients of every optimized parameter."""
        self.opt.zero_grad()

    def train_mode(self):
        """Switch the solver and all four networks to training mode."""
        super(SolverNMsgCond, self).train_mode()
        self.enc_m.train()
        self.enc_c.train()
        self.dec_c.train()
        self.dec_m.train()

    def eval_mode(self):
        """Switch the solver and all four networks to evaluation mode."""
        super(SolverNMsgCond, self).eval_mode()
        # Evaluation must use .eval(): calling .train() here would keep
        # dropout / normalisation layers in their training behaviour.
        self.enc_m.eval()
        self.enc_c.eval()
        self.dec_c.eval()
        self.dec_m.eval()

    def step(self):
        """Apply one optimizer update and step the LR schedule when due."""
        self.opt.step()
        # The scheduler advances whenever the iteration counter is a multiple
        # of the number of training batches.
        if self.cur_iter % len(self.train_loader) == 0:
            self.lr_sched.step()

    def incur_loss(self, carrier, carrier_reconst, msg, msg_reconst):
        """Compute the weighted carrier + message reconstruction loss.

        Args:
            carrier: original carrier tensor.
            carrier_reconst: reconstructed carrier.
            msg: list of original message tensors.
            msg_reconst: reconstructed messages, indexable like ``msg``.

        Returns:
            ``(loss, losses_log)``: the weighted total loss to backpropagate,
            and a dict with the scalar ``carrier_loss`` and ``avg_msg_loss``.
        """
        n_messages = len(msg)
        losses_log = defaultdict(int)
        carrier, msg = carrier.to(
            self.device), [msg_i.to(self.device) for msg_i in msg]
        all_msg_loss = 0
        carrier_loss = self.reconstruction_loss(carrier_reconst,
                                                carrier,
                                                type=self.loss_type)
        for i in range(n_messages):
            msg_loss = self.reconstruction_loss(msg_reconst[i],
                                                msg[i],
                                                type=self.loss_type)
            all_msg_loss += msg_loss
        losses_log['carrier_loss'] = carrier_loss.item()
        losses_log['avg_msg_loss'] = all_msg_loss.item() / self.n_messages
        loss = self.lambda_carrier_loss * carrier_loss + self.lambda_msg_loss * all_msg_loss

        return loss, losses_log

    def forward(self, carrier, carrier_phase, msg):
        """Embed ``msg`` in ``carrier`` and decode the messages back.

        Args:
            carrier: carrier tensor; its first dimension is the batch size.
            carrier_phase: phase tensor, used only when the reconstructed
                carrier is passed through the inverse/forward STFT.
            msg: list of message tensors, indexed ``0 .. n_messages - 1``.

        Returns:
            ``(carrier_reconst, msg_reconst_list)``: the reconstructed carrier
            and one reconstructed message per message index.
        """
        assert isinstance(carrier, torch.Tensor) and isinstance(msg, list)
        batch_size = carrier.shape[0]
        carrier, carrier_phase, msg = carrier.to(
            self.device), carrier_phase.to(
                self.device), [msg_i.to(self.device) for msg_i in msg]
        msg_encoded_list = []
        msg_reconst_list = []

        # encode carrier
        carrier_enc = self.enc_c(carrier)

        # Condition each message on its index. The conditioned messages are
        # concatenated as-is; enc_m is not applied in this method.
        for i in range(self.n_messages):
            # create one-hot vectors for msg index
            cond = torch.tensor(()).new_full((batch_size, ), i)
            cond = self.label2onehot(cond, self.n_messages).to(self.device)
            # concat conditioning vectors to input
            msg_i = self.concat_cond(msg[i], cond)
            msg_encoded_list.append(msg_i)

        # merge encodings and reconstruct carrier
        msg_enc = torch.cat(msg_encoded_list, dim=1)
        merged_enc = torch.cat((carrier, carrier_enc, msg_enc),
                               dim=1)  # concat encodings on features axis
        carrier_reconst = self.dec_c(merged_enc)

        # Past the configured iteration, stop gradients from flowing back
        # from the message decoder into the carrier reconstruction.
        if self.carrier_detach != -1 and self.cur_iter > self.carrier_detach:
            carrier_reconst = carrier_reconst.detach()

        # add stft noise to carrier: round-trip through inverse + forward
        # STFT, always at test time and optionally after a given iteration.
        if (self.add_stft_noise != -1 and
                self.cur_iter > self.add_stft_noise) or self.mode == 'test':
            self.stft.to(self.device)
            y = self.stft.inverse(carrier_reconst.squeeze(1),
                                  carrier_phase.squeeze(1))
            carrier_reconst_tag, _ = self.stft.transform(y.squeeze(1))
            carrier_reconst_tag = carrier_reconst_tag.unsqueeze(1)
            self.stft.to('cpu')
        else:
            carrier_reconst_tag = carrier_reconst

        # decode messages from carrier, one pass per message index
        for i in range(self.n_messages):
            cond = torch.tensor(()).new_full((batch_size, ), i)
            cond = self.label2onehot(cond, self.n_messages).to(self.device)
            msg_reconst = self.dec_m(
                self.concat_cond(carrier_reconst_tag, cond))
            msg_reconst_list.append(msg_reconst)

        return carrier_reconst, msg_reconst_list

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors.

        Args:
            labels: 1-D tensor of label indices (cast to long internally).
            dim: number of classes, i.e. the width of each one-hot row.

        Returns:
            A ``(labels.size(0), dim)`` tensor with a single 1 per row.
        """
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def concat_cond(self, x, c):
        """Append the conditioning vector ``c`` to ``x`` as constant planes.

        ``c`` is reshaped to ``(N, C, 1, 1)``, tiled over dimensions 2 and 3
        of ``x``, and concatenated to ``x`` along dimension 1.

        Args:
            x: 4-D input tensor.
            c: 2-D conditioning tensor with the same first dimension as ``x``.

        Returns:
            ``x`` with ``c.size(1)`` extra channels.
        """
        # Replicate spatially and concatenate domain information.
        c = c.view(c.size(0), c.size(1), 1, 1)
        c = c.repeat(1, 1, x.size(2), x.size(3))
        return torch.cat([x, c], dim=1)
예제 #23
0
파일: main.py 프로젝트: yugenlgy/MultiRD
def main(epoch_num, batch_size, verbose, UNSEEN, SEEN, MODE):
    """Train the definition encoder and evaluate on the test sets.

    Every epoch runs one training pass and one validation pass. Whenever the
    validation top-10 accuracy improves, the model is evaluated on the combined
    test sets, and after epoch 10 the test labels/predictions are dumped to
    ``<mode>_label_list.json`` / ``<mode>_pred_list.json``.

    Args:
        epoch_num: number of training epochs.
        batch_size: batch size for the train/valid/test data loaders.
        verbose: forwarded to ``tqdm(disable=...)``; a truthy value hides the
            progress bars.
        UNSEEN: if truthy (and ``SEEN`` is falsy), train on the first 90% of
            the train/dev definitions only.
        SEEN: if truthy, additionally train on ``word_defi_idx_seen``. Takes
            precedence over ``UNSEEN``.
        MODE: forwarded to the model as ``mode=`` and used, prefixed with
            ``S_``/``U_``, to name the dumped JSON files.

    Raises:
        ValueError: if neither ``SEEN`` nor ``UNSEEN`` is set. Previously this
            surfaced later as a ``NameError`` on ``mode``.
    """
    # Fail fast: without one of the two flags there is no training split.
    if not (SEEN or UNSEEN):
        raise ValueError('Either SEEN or UNSEEN must be set.')

    hownet_file = 'hownet.json'
    sememe_file = 'sememe.json'
    word_index_file = 'word_index.json'
    word_vector_file = 'word_vector.npy'
    dictionary_file = 'dictionary_sense.json'
    word_cilinClass_file = 'word_cilinClass.json'

    word2index, index2word, word2vec, sememe_num, label_size, label_size_chara, word_defi_idx_all = load_data(
        hownet_file, sememe_file, word_index_file, word_vector_file,
        dictionary_file, word_cilinClass_file)
    (word_defi_idx_TrainDev, word_defi_idx_seen, word_defi_idx_test2000,
     word_defi_idx_test200, word_defi_idx_test272) = word_defi_idx_all
    index2word = np.array(index2word)
    length = len(word_defi_idx_TrainDev)
    # First 90% of the train/dev list is used for training, the rest for
    # validation.
    split = int(0.9 * length)
    valid_dataset = MyDataset(word_defi_idx_TrainDev[split:])
    test_dataset = MyDataset(word_defi_idx_test2000 + word_defi_idx_test200 +
                             word_defi_idx_test272)
    if SEEN:
        mode = 'S_' + MODE
        print('*METHOD: Seen defi.')
        print('*TRAIN: [Train + allSeen(2000+200+272)]')
        print('*TEST: [2000rand1 + 200desc + 272desc]')
        train_dataset = MyDataset(word_defi_idx_TrainDev[:split] +
                                  word_defi_idx_seen)
    else:
        mode = 'U_' + MODE
        print('*METHOD: Unseen All words and defi.')
        print('*TRAIN: [Train]')
        print('*TEST: [2000rand1 + 200desc + 272desc]')
        train_dataset = MyDataset(word_defi_idx_TrainDev[:split])
    print('*MODE: [%s]' % mode)

    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   collate_fn=my_collate_fn)
    valid_dataloader = torch.utils.data.DataLoader(valid_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   collate_fn=my_collate_fn)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=my_collate_fn_test)

    print('Train dataset: ', len(train_dataset))
    print('Valid dataset: ', len(valid_dataset))
    print('Test dataset: ', len(test_dataset))
    word_defi_idx = word_defi_idx_TrainDev + word_defi_idx_seen

    # Per-word label tables built by the project helpers, truncated to the
    # first ``label_size`` rows and moved to the compute device.
    wd2sem = word2sememe(word_defi_idx, len(word2index), sememe_num)
    wd_sems = label_multihot(wd2sem, sememe_num)
    wd_sems = torch.from_numpy(np.array(wd_sems[:label_size])).to(device)
    wd_POSs = label_multihot(word2POS(word_defi_idx, len(word2index), 13), 13)
    wd_POSs = torch.from_numpy(np.array(wd_POSs[:label_size])).to(device)
    wd_charas = label_multihot(
        word2chara(word_defi_idx, len(word2index), label_size_chara),
        label_size_chara)
    wd_charas = torch.from_numpy(np.array(wd_charas[:label_size])).to(device)
    wd2Cilin1 = word2Cn(word_defi_idx, len(word2index), 'C1', 13)
    wd_C1 = label_multihot(wd2Cilin1, 13)  # class counts per level: 13 96 1426 4098
    wd_C1 = torch.from_numpy(np.array(wd_C1[:label_size])).to(device)
    wd_C2 = label_multihot(word2Cn(word_defi_idx, len(word2index), 'C2', 96),
                           96)
    wd_C2 = torch.from_numpy(np.array(wd_C2[:label_size])).to(device)
    wd_C3 = label_multihot(word2Cn(word_defi_idx, len(word2index), 'C3', 1426),
                           1426)
    wd_C3 = torch.from_numpy(np.array(wd_C3[:label_size])).to(device)
    wd_C4 = label_multihot(word2Cn(word_defi_idx, len(word2index), 'C4', 4098),
                           4098)
    wd_C4 = torch.from_numpy(np.array(wd_C4[:label_size])).to(device)
    wd_C = [wd_C1, wd_C2, wd_C3, wd_C4]

    # mask_s[i] == 1 when word i has no entry other than the value
    # ``sememe_num`` (presumably the padding id -- confirm in word2sememe).
    print('calculating mask of no sememes...')
    mask_s = torch.zeros(label_size, dtype=torch.float32, device=device)
    for i in range(label_size):
        sems = set(wd2sem[i].detach().cpu().numpy().tolist()) - {sememe_num}
        if len(sems) == 0:
            mask_s[i] = 1

    # Same idea for the level-1 Cilin classes, with 13 as the excluded value.
    mask_c = torch.zeros(label_size, dtype=torch.float32, device=device)
    for i in range(label_size):
        cc = set(wd2Cilin1[i].detach().cpu().numpy().tolist()) - {13}
        if len(cc) == 0:
            mask_c[i] = 1

    model = Encoder(vocab_size=len(word2index),
                    embed_dim=word2vec.shape[1],
                    hidden_dim=200,
                    layers=1,
                    class_num=label_size,
                    sememe_num=sememe_num,
                    chara_num=label_size_chara)
    # Initialise the embedding table from the loaded word vectors.
    model.embedding.weight.data = torch.from_numpy(word2vec)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Adam
    best_valid_accu = 0
    for epoch in range(epoch_num):
        print('epoch: ', epoch)
        model.train()
        train_loss = 0
        label_list = list()
        pred_list = list()
        for words_t, sememes_t, definition_words_t, POS_t, sememes, POSs, charas_t, C, C_t in tqdm(
                train_dataloader, disable=verbose):
            optimizer.zero_grad()
            loss, _, indices = model('train',
                                     x=definition_words_t,
                                     w=words_t,
                                     ws=wd_sems,
                                     wP=wd_POSs,
                                     wc=wd_charas,
                                     wC=wd_C,
                                     msk_s=mask_s,
                                     msk_c=mask_c,
                                     mode=MODE)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
            # Keep only the first 100 predicted indices per sample.
            predicted = indices[:, :100].detach().cpu().numpy().tolist()
            train_loss += loss.item()
            label_list.extend(words_t.detach().cpu().numpy())
            pred_list.extend(predicted)
        train_accu_1, train_accu_10, train_accu_100 = evaluate(
            label_list, pred_list)
        del label_list
        del pred_list
        gc.collect()
        print('train_loss: ', train_loss / len(train_dataset))
        print('train_accu(1/10/100): %.2f %.2F %.2f' %
              (train_accu_1, train_accu_10, train_accu_100))
        model.eval()
        with torch.no_grad():
            valid_loss = 0
            label_list = []
            pred_list = []
            for words_t, sememes_t, definition_words_t, POS_t, sememes, POSs, charas_t, C, C_t in tqdm(
                    valid_dataloader, disable=verbose):
                # The model is called with 'train' here so that it also
                # returns a loss; gradients are disabled by the context above.
                loss, _, indices = model('train',
                                         x=definition_words_t,
                                         w=words_t,
                                         ws=wd_sems,
                                         wP=wd_POSs,
                                         wc=wd_charas,
                                         wC=wd_C,
                                         msk_s=mask_s,
                                         msk_c=mask_c,
                                         mode=MODE)
                predicted = indices[:, :100].detach().cpu().numpy().tolist()
                valid_loss += loss.item()
                label_list.extend(words_t.detach().cpu().numpy())
                pred_list.extend(predicted)
            valid_accu_1, valid_accu_10, valid_accu_100 = evaluate(
                label_list, pred_list)
            print('valid_loss: ', valid_loss / len(valid_dataset))
            print('valid_accu(1/10/100): %.2f %.2F %.2f' %
                  (valid_accu_1, valid_accu_10, valid_accu_100))
            del label_list
            del pred_list
            gc.collect()

            # Model selection is driven by validation top-10 accuracy.
            if valid_accu_10 > best_valid_accu:
                best_valid_accu = valid_accu_10
                print('-----best_valid_accu-----')
                label_list = []
                pred_list = []
                for words_t, definition_words_t in tqdm(test_dataloader,
                                                        disable=verbose):
                    indices = model('test',
                                    x=definition_words_t,
                                    w=words_t,
                                    ws=wd_sems,
                                    wP=wd_POSs,
                                    wc=wd_charas,
                                    wC=wd_C,
                                    msk_s=mask_s,
                                    msk_c=mask_c,
                                    mode=MODE)
                    predicted = indices[:, :1000].detach().cpu().numpy(
                    ).tolist()
                    label_list.extend(words_t.detach().cpu().numpy())
                    pred_list.extend(predicted)
                test_accu_1, test_accu_10, test_accu_100, median, variance = evaluate_test(
                    label_list, pred_list)
                print('test_accu(1/10/100): %.2f %.2F %.2f %.1f %.2f' %
                      (test_accu_1, test_accu_10, test_accu_100, median,
                       variance))
                if epoch > 10:
                    # Map indices back to words, then write with context
                    # managers so the file handles are always closed.
                    label_words = (index2word[label_list]).tolist()
                    with open(mode + '_label_list.json', 'w') as label_file:
                        json.dump(label_words, label_file)
                    pred_words = (index2word[np.array(pred_list)]).tolist()
                    with open(mode + '_pred_list.json', 'w') as pred_file:
                        json.dump(pred_words, pred_file)
                del label_list
                del pred_list
                gc.collect()
예제 #24
0
def main(args):
    """Train the seq2seq translation model and evaluate the best checkpoint.

    Sets up log/checkpoint directories and logging, builds vocabularies,
    tokenizers and data loaders, trains for ``args.epochs`` epochs while
    checkpointing whenever the validation corpus BLEU improves, and finally
    runs ``eval`` on the test loader with the best checkpoint file.

    Args:
        args: namespace holding paths (``LOG_DIR``, ``CHK_DIR``,
            ``RESULT_DIR``, ``DATA_DIR``, ``TRAIN_VOCAB_EN``,
            ``TRAIN_VOCAB_ZH``), model settings (``model_type``,
            ``wordembed_dim``, ``enc_hid_size``, ``dec_hid_size``) and
            training settings (``train_id``, ``batch_size``, ``epochs``,
            ``encoder_lr``, ``decoder_lr``, ``weight_decay``,
            ``MAX_INPUT_LENGTH``, ``MAX_VID_LENGTH``).

    Raises:
        ValueError: if ``args.model_type`` is not ``'s2s'``. Previously this
            surfaced later as a ``NameError`` on ``encoder``.
    """
    model_prefix = '{}_{}'.format(args.model_type, args.train_id)  # model name prefix

    # Paths derived from the arguments.
    log_path = args.LOG_DIR + model_prefix + '/'
    checkpoint_path = args.CHK_DIR + model_prefix + '/'
    result_path = args.RESULT_DIR + model_prefix + '/'
    cp_file = checkpoint_path + "best_model.pth.tar"
    init_epoch = 0

    # Create the output directories; exist_ok avoids the check-then-create race.
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(checkpoint_path, exist_ok=True)

    # Initialise logging.
    set_logger(os.path.join(log_path, 'train.log'))

    # Save the arguments next to the logs so the run can be reproduced.
    with open(log_path + 'args.yaml', 'w') as f:
        for k, v in args.__dict__.items():
            f.write('{}: {}\n'.format(k, v))

    logging.info('Training model: {}'.format(model_prefix))

    # Build the vocabularies (project helper).
    setup(args, clear=True)
    print(args.__dict__)

    # Source and target languages. The model can translate in either
    # direction; the direction is chosen here.
    src, tgt = 'en', 'zh'

    # Language code -> vocabulary file path.
    maps = {
        'en': args.TRAIN_VOCAB_EN,
        'zh': args.TRAIN_VOCAB_ZH
    }
    vocab_src = read_vocab(maps[src])  # load the source vocabulary from its path
    # Tokenizer wraps the vocabulary and handles encoding.
    tok_src = Tokenizer(language=src,
                        vocab=vocab_src,
                        encoding_length=args.MAX_INPUT_LENGTH)
    vocab_tgt = read_vocab(maps[tgt])  # same for the target side
    tok_tgt = Tokenizer(language=tgt,
                        vocab=vocab_tgt,
                        encoding_length=args.MAX_INPUT_LENGTH)
    logging.info('Vocab size src/tgt:{}/{}'.format(len(vocab_src),
                                                   len(vocab_tgt)))

    # Build the training, validation and testing data loaders.
    train_loader, val_loader, test_loader = create_split_loaders(
        args.DATA_DIR, (tok_src, tok_tgt),
        args.batch_size,
        args.MAX_VID_LENGTH, (src, tgt),
        num_workers=4,
        pin_memory=True)
    logging.info('train/val/test size: {}/{}/{}'.format(
        len(train_loader), len(val_loader), len(test_loader)))

    # Initialise the model; seq2seq is the only supported type.
    if args.model_type == 's2s':
        encoder = Encoder(vocab_size=len(vocab_src),
                          embed_size=args.wordembed_dim,
                          hidden_size=args.enc_hid_size).cuda()
        decoder = Decoder(embed_size=args.wordembed_dim,
                          hidden_size=args.dec_hid_size,
                          vocab_size=len(vocab_tgt)).cuda()
    else:
        raise ValueError('Unsupported model_type: {}'.format(args.model_type))

    # Switch both networks to training mode.
    encoder.train()
    decoder.train()

    # Cross-entropy loss that ignores padding positions.
    criterion = nn.CrossEntropyLoss(ignore_index=padding_idx).cuda()
    # Separate Adam optimizers for encoder and decoder (trainable params only).
    dec_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                                   decoder.parameters()),
                                     lr=args.decoder_lr,
                                     weight_decay=args.weight_decay)
    enc_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                                   encoder.parameters()),
                                     lr=args.encoder_lr,
                                     weight_decay=args.weight_decay)

    count_paras(encoder, decoder, logging)  # reports the parameter count

    # Loss history and best-result bookkeeping.
    total_train_loss, total_val_loss = [], []
    best_val_bleu, best_epoch = 0, 0

    # Reference time for the whole run.
    zero_time = time.time()

    for epoch in range(init_epoch, args.epochs):
        start_time = time.time()
        # One training pass over the data.
        train_loss = train(train_loader, encoder, decoder, criterion,
                           enc_optimizer, dec_optimizer, epoch)

        # One validation pass.
        val_loss, sentbleu, corpbleu = validate(val_loader, encoder, decoder,
                                                criterion)
        end_time = time.time()

        # Timing for this epoch and for the run so far.
        epoch_time = end_time - start_time
        total_time = end_time - zero_time

        logging.info(
            'Total time used: %s Epoch %d time uesd: %s train loss: %.4f val loss: %.4f sentbleu: %.4f corpbleu: %.4f'
            % (str(datetime.timedelta(seconds=int(total_time))), epoch,
               str(datetime.timedelta(seconds=int(epoch_time))), train_loss,
               val_loss, sentbleu, corpbleu))

        # Keep the checkpoint with the best validation corpus BLEU.
        if corpbleu > best_val_bleu:
            best_val_bleu = corpbleu
            save_checkpoint(
                {
                    'epoch': epoch,
                    'enc_state_dict': encoder.state_dict(),
                    'dec_state_dict': decoder.state_dict(),
                    'enc_optimizer': enc_optimizer.state_dict(),
                    'dec_optimizer': dec_optimizer.state_dict(),
                }, cp_file)
            best_epoch = epoch

        logging.info("Finished {0} epochs of training".format(epoch + 1))

        # Record the losses.
        total_train_loss.append(train_loss)
        total_val_loss.append(val_loss)

    logging.info('Best corpus bleu score {:.4f} at epoch {}'.format(
        best_val_bleu, best_epoch))

    # The best checkpoint was saved above and is used for the final test run.
    logging.info('************ Start eval... ************')
    eval(test_loader, encoder, decoder, cp_file, tok_tgt, result_path)
예제 #25
0
        args.start_epoch = checkpoint['epoch']
        #   encoder.load_state_dict(checkpoint['encoder'])
        #  decoder.load_state_dict(checkpoint['decoder'])
        classifier.load_state_dict(checkpoint['classifier'])

    ###############################################################################
    # Training code
    ###############################################################################

    for epoch in range(args.epochs):
        epoch_start_time = time.time()
        loss_z = AverageMeter()
        loss_recon = AverageMeter()
        loss_classify = AverageMeter()
        loss = AverageMeter()
        encoder.train()
        decoder.train()
        classifier.train()
        for batch_idx, data in enumerate(train_loader):
            optimizer.zero_grad()
            batch_start_time = time.time()
            img_1 = data[0].cuda()
            img_2 = data[1].cuda()
            img_1_atts = data[2].cuda()
            img_2_atts = data[3].cuda()

            z_1 = encoder(img_1)
            z_2 = encoder(img_2)
            img_2_trans = decoder(z_1, img_2_atts)
            img_1_trans = decoder(z_2, img_1_atts)
            img_1_recon = decoder(z_1, img_1_atts)
예제 #26
0
파일: main.py 프로젝트: ssumin6/transformer
def main(args):
    """Train or test the Transformer translation model, depending on ``args.test``.

    Training mode (``args.test`` falsy): runs ``args.epochs`` epochs with the
    warm-up learning-rate schedule, validates after each epoch, saves a
    checkpoint when the last validation batch loss does not exceed the best
    seen so far, and plots the validation losses every 20 epochs.

    Test mode: greedily decodes the test set, writes the sentences to
    ``results/pred.txt`` and invokes the BLEU script.

    Args:
        args: namespace with ``path``, ``model_load``, ``test``,
            ``batch_size`` and ``epochs``.
    """
    src, tgt = load_data(args.path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    src_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    max_length = 50

    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    N = 6
    dim = 512

    # MODEL Construction
    encoder = Encoder(N, dim, pad_idx, src_vocab_size, device).to(device)
    decoder = Decoder(N, dim, pad_idx, tgt_vocab_size, device).to(device)

    if args.model_load:
        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only host (and vice versa).
        ckpt = torch.load("drive/My Drive/checkpoint/best.ckpt",
                          map_location=device)
        encoder.load_state_dict(ckpt["encoder"])
        decoder.load_state_dict(ckpt["decoder"])

    params = list(encoder.parameters()) + list(decoder.parameters())

    if not args.test:
        train_loader = get_loader(src['train'],
                                  tgt['train'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size,
                                  shuffle=True)
        valid_loader = get_loader(src['valid'],
                                  tgt['valid'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size)

        # Learning rate: dim^-0.5 * min(step^-0.5, step * warmup^-1.5).
        warmup = 4000
        steps = 1
        lr = 1. * (dim**-0.5) * min(steps**-0.5, steps * (warmup**-1.5))
        optimizer = torch.optim.Adam(params,
                                     lr=lr,
                                     betas=(0.9, 0.98),
                                     eps=1e-09)

        train_losses = []
        val_losses = []
        latest = 1e08  # lowest validation loss seen so far

        start_epoch = 0

        if args.model_load:
            start_epoch = ckpt["epoch"]
            optimizer.load_state_dict(ckpt["optim"])
            # NOTE(review): assumes 30 optimisation steps per epoch when
            # resuming the schedule -- confirm against the loader size.
            steps = start_epoch * 30

        for epoch in range(start_epoch, args.epochs):
            # Validation below leaves the networks in eval mode, so training
            # mode is restored at the start of every epoch.
            encoder.train()
            decoder.train()
            for src_batch, tgt_batch in train_loader:
                optimizer.zero_grad()
                tgt_batch = torch.LongTensor(tgt_batch)

                src_batch = torch.LongTensor(src_batch).to(device)
                # Teacher forcing: the decoder sees tokens [0, T-1) and is
                # trained to predict tokens [1, T).
                gt = tgt_batch[:, 1:].to(device)
                tgt_batch = tgt_batch[:, :-1].to(device)

                enc_output, seq_mask = encoder(src_batch)
                dec_output = decoder(tgt_batch, enc_output, seq_mask)

                gt = gt.view(-1)
                dec_output = dec_output.view(gt.size()[0], -1)

                loss = F.cross_entropy(dec_output, gt, ignore_index=pad_idx)
                loss.backward()
                train_losses.append(loss.item())
                optimizer.step()

                steps += 1
                lr = (dim**-0.5) * min(steps**-0.5, steps * (warmup**-1.5))
                update_lr(optimizer, lr)

                if steps % 10 == 0:
                    print("loss : %f" % loss.item())

            encoder.eval()
            decoder.eval()
            # No gradients are needed for validation; skipping graph
            # construction saves memory and time.
            with torch.no_grad():
                for src_batch, tgt_batch in valid_loader:
                    src_batch = torch.LongTensor(src_batch).to(device)
                    tgt_batch = torch.LongTensor(tgt_batch)
                    gt = tgt_batch[:, 1:].to(device)
                    tgt_batch = tgt_batch[:, :-1].to(device)

                    enc_output, seq_mask = encoder(src_batch)
                    dec_output = decoder(tgt_batch, enc_output, seq_mask)

                    gt = gt.view(-1)
                    dec_output = dec_output.view(gt.size()[0], -1)

                    loss = F.cross_entropy(dec_output,
                                           gt,
                                           ignore_index=pad_idx)

                    val_losses.append(loss.item())
            # NOTE(review): this reports, and the check below compares, the
            # loss of the last validation batch, not the epoch mean.
            print("[EPOCH %d] Loss %f" % (epoch, loss.item()))

            if val_losses[-1] <= latest:
                checkpoint = {
                    'encoder': encoder.state_dict(),
                    'decoder': decoder.state_dict(),
                    'optim': optimizer.state_dict(),
                    'epoch': epoch
                }
                torch.save(checkpoint, "drive/My Drive/checkpoint/best.ckpt")
                latest = val_losses[-1]

            if epoch % 20 == 0:
                plt.figure()
                plt.plot(val_losses)
                plt.xlabel("epoch")
                plt.ylabel("model loss")
                plt.show()

    else:
        # test
        test_loader = get_loader(src['test'],
                                 tgt['test'],
                                 src_vocab,
                                 tgt_vocab,
                                 batch_size=args.batch_size)

        pred = []
        encoder.eval()
        decoder.eval()
        # Greedy decoding is inference only, so autograd is disabled.
        with torch.no_grad():
            for src_batch, tgt_batch in test_loader:
                b_s = min(args.batch_size, len(src_batch))
                # Every hypothesis starts with a single <sos> token.
                tgt_batch = torch.full((b_s, 1),
                                       sos_idx,
                                       dtype=torch.long,
                                       device=device)
                src_batch = torch.LongTensor(src_batch).to(device)

                enc_output, seq_mask = encoder(src_batch)
                pred_batch = decoder(tgt_batch, enc_output, seq_mask)
                _, pred_batch = torch.max(pred_batch, 2)

                while not is_finished(pred_batch, max_length, eos_idx):
                    # Feed <sos> plus everything predicted so far back into
                    # the decoder and take the arg-max at every position.
                    next_input = torch.cat((tgt_batch, pred_batch.long()), 1)
                    pred_batch = decoder(next_input, enc_output, seq_mask)
                    _, pred_batch = torch.max(pred_batch, 2)
                # every sentences in pred_batch should start with <sos> token (index: 0) and end with <eos> token (index: 1).
                # every <pad> token (index: 2) should be located after <eos> token (index: 1).
                # example of pred_batch:
                # [[0, 5, 6, 7, 1],
                #  [0, 4, 9, 1, 2],
                #  [0, 6, 1, 2, 2]]
                pred_batch = pred_batch.tolist()
                for line in pred_batch:
                    # Force the last position to <eos> (index 1).
                    line[-1] = 1
                pred += seq2sen(pred_batch, tgt_vocab)

        with open('results/pred.txt', 'w') as f:
            for line in pred:
                f.write('{}\n'.format(line))

        os.system(
            'bash scripts/bleu.sh results/pred.txt multi30k/test.de.atok')
예제 #27
0
def train(args, logger):
    """Train the encoder/decoder on a DGL graph with per-epoch validation and testing.

    For each of ``args.n_epoch`` epochs the per-node state stored on the graph
    (``mail``, ``feat``, ``last_update``) is reset, the model is trained over
    ``train_loader``, and then evaluated on the validation and test loaders.
    The encoder/decoder weights are saved whenever the early-stop monitor
    reports the current epoch as the best one, and training stops (after
    reloading those weights) when the monitor signals early stop.

    Args:
        args: run configuration; the attributes read here include ``data``,
            ``tasks``, ``prefix``, ``gpu``, ``n_head``, ``dropout``, ``lr``,
            ``weight_decay``, ``warmup``, ``patience``, ``pretrain``,
            ``n_epoch``, ``n_mail``, ``no_time``, ``no_pos``, ``balance`` and
            ``bs``.
        logger: logger used for progress and metric messages.

    Note:
        ``task_time`` and ``test_result`` are assigned but not used or
        returned in the visible code.
    """
    task_time = time.strftime("%Y-%m-%d %H:%M", time.localtime())
    # Make sure the output / pretrained-weights directories exist.
    Path("./saved_models/").mkdir(parents=True, exist_ok=True)
    Path("./pretrained_models/").mkdir(parents=True, exist_ok=True)
    MODEL_SAVE_PATH = './saved_models/'
    Pretrained_MODEL_PATH = './pretrained_models/'
    # File-name builders for this run's weights and for link-prediction ('LP')
    # pretrained weights.
    get_model_name = lambda part: f'{part}-{args.data}-{args.tasks}-{args.prefix}.pth'
    get_pretrain_model_name = lambda part: f'{part}-{args.data}-LP-{args.prefix}.pth'
    # Use the requested GPU only if CUDA is available and a non-negative id was given.
    device_string = 'cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu >=0 else 'cpu'
    print('Model trainging with '+device_string)
    device = torch.device(device_string)



    # The saved file is expected to contain the graph as its first entry.
    g = load_graphs(f"./data/{args.data}.dgl")[0][0]

    efeat_dim = g.edata['feat'].shape[1]
    # Node features use the same width as the edge features.
    nfeat_dim = efeat_dim


    train_loader, val_loader, test_loader, num_val_samples, num_test_samples = dataloader(args, g)


    encoder = Encoder(args, nfeat_dim, n_head=args.n_head, dropout=args.dropout).to(device)
    decoder = Decoder(args, nfeat_dim).to(device)
    msg2mail = Msg2Mail(args, nfeat_dim)
    fraud_sampler = frauder_sampler(g)

    # A single optimizer drives both encoder and decoder parameters.
    optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=args.lr, weight_decay=args.weight_decay)
    scheduler_lr = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=40)
    if args.warmup:
        scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=3, after_scheduler=scheduler_lr)
        # Dummy optimizer step with zeroed gradients; presumably here to avoid
        # the "scheduler.step() before optimizer.step()" warning -- confirm.
        optimizer.zero_grad()
        optimizer.step()
    loss_fcn = torch.nn.BCEWithLogitsLoss()

    loss_fcn = loss_fcn.to(device)

    early_stopper = EarlyStopMonitor(logger=logger, max_round=args.patience, higher_better=True)

    if args.pretrain:
        logger.info(f'Loading the linkpred pretrained attention based encoder model')
        encoder.load_state_dict(torch.load(Pretrained_MODEL_PATH+get_pretrain_model_name('Encoder')))

    for epoch in range(args.n_epoch):
        # reset node state
        g.ndata['mail'] = torch.zeros((g.num_nodes(), args.n_mail, nfeat_dim+2), dtype=torch.float32) 
        g.ndata['feat'] = torch.zeros((g.num_nodes(), nfeat_dim), dtype=torch.float32) # init as zero, people can init it using others.
        g.ndata['last_update'] = torch.zeros((g.num_nodes()), dtype=torch.float32) 
        encoder.train()
        decoder.train()
        start_epoch = time.time()
        m_loss = []
        logger.info('start {} epoch, current optim lr is {}'.format(epoch, optimizer.param_groups[0]['lr']))
        for batch_idx, (input_nodes, pos_graph, neg_graph, blocks, frontier, current_ts) in enumerate(train_loader):


            pos_graph = pos_graph.to(device)
            neg_graph = neg_graph.to(device) if neg_graph is not None else None


            # Timestamps are computed unless both time and position encodings
            # are disabled.
            if not args.no_time or not args.no_pos:
                current_ts, pos_ts, num_pos_nodes = get_current_ts(args, pos_graph, neg_graph)
                pos_graph.ndata['ts'] = current_ts
            else:
                current_ts, pos_ts, num_pos_nodes = None, None, None

            # `_` temporarily holds the reverse-edge-augmented negative graph
            # and is then overwritten by the encoder's second return value.
            _ = dgl.add_reverse_edges(neg_graph) if neg_graph is not None else None
            emb, _ = encoder(dgl.add_reverse_edges(pos_graph), _, num_pos_nodes)
            # The first batch only produces embeddings and mails; no loss or
            # optimizer step is taken for it.
            if batch_idx != 0:
                # For non-LP tasks with balancing enabled, the negative graph
                # is replaced by one drawn from the fraud sampler.
                if 'LP' not in args.tasks and args.balance:
                    neg_graph = fraud_sampler.sample_fraud_event(g, args.bs//5, current_ts.max().cpu()).to(device)
                logits, labels = decoder(emb, pos_graph, neg_graph)

                loss = loss_fcn(logits, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                m_loss.append(loss.item())


            # MSG Passing
            # Write the new embeddings, update times and generated mails back
            # onto the full graph without tracking gradients.
            with torch.no_grad():
                mail = msg2mail.gen_mail(args, emb, input_nodes, pos_graph, frontier, 'train')

                if not args.no_time:
                    g.ndata['last_update'][pos_graph.ndata[dgl.NID][:num_pos_nodes]] = pos_ts.to('cpu')
                g.ndata['feat'][pos_graph.ndata[dgl.NID]] = emb.to('cpu')
                g.ndata['mail'][input_nodes] = mail
            # Log resource usage and the running mean loss at steps 1, 101, 201, ...
            if batch_idx % 100 == 1:
                # 1.074e9 is approximately 2**30, i.e. bytes -> GiB.
                gpu_mem = torch.cuda.max_memory_allocated() / 1.074e9 if torch.cuda.is_available() and args.gpu >= 0 else 0
                torch.cuda.empty_cache()
                mem_perc = psutil.virtual_memory().percent
                cpu_perc = psutil.cpu_percent(interval=None)
                output_string = f'Epoch {epoch} | Step {batch_idx}/{len(train_loader)} | CPU {cpu_perc:.1f}% | Sys Mem {mem_perc:.1f}% | GPU Mem {gpu_mem:.4f}GB '

                output_string += f'| {args.tasks} Loss {np.mean(m_loss):.4f}'

                logger.info(output_string)

        total_epoch_time = time.time() - start_epoch
        logger.info(' training epoch: {} took {:.4f}s'.format(epoch, total_epoch_time))
        val_ap, val_auc, val_acc, val_loss = eval_epoch(args, logger, g, val_loader, encoder, decoder, msg2mail, loss_fcn, device, num_val_samples)
        logger.info('Val {} Task | ap: {:.4f} | auc: {:.4f} | acc: {:.4f} | Loss: {:.4f}'.format(args.tasks, val_ap, val_auc, val_acc, val_loss))

        # Advance the learning-rate schedule once per epoch.
        if args.warmup:
            scheduler_warmup.step(epoch)
        else:
            scheduler_lr.step()

        # Link-prediction tasks are monitored on AP, all others on AUC.
        early_stopper_metric = val_ap if 'LP' in args.tasks else val_auc

        if early_stopper.early_stop_check(early_stopper_metric):
            logger.info('No improvement over {} epochs, stop training'.format(early_stopper.max_round))
            logger.info(f'Loading the best model at epoch {early_stopper.best_epoch}')
            encoder.load_state_dict(torch.load(MODEL_SAVE_PATH+get_model_name('Encoder')))
            decoder.load_state_dict(torch.load(MODEL_SAVE_PATH+get_model_name('Decoder')))

            # Report the test metrics recorded at the best epoch.
            test_result = [early_stopper.best_ap, early_stopper.best_auc, early_stopper.best_acc, early_stopper.best_loss]
            break

        test_ap, test_auc, test_acc, test_loss = eval_epoch(args, logger, g, test_loader, encoder, decoder, msg2mail, loss_fcn, device, num_test_samples)
        logger.info('Test {} Task | ap: {:.4f} | auc: {:.4f} | acc: {:.4f} | Loss: {:.4f}'.format(args.tasks, test_ap, test_auc, test_acc, test_loss))
        test_result = [test_ap, test_auc, test_acc, test_loss]

        # Presumably early_stop_check sets best_epoch when the metric improves
        # (confirm in EarlyStopMonitor); in that case store this epoch's test
        # metrics on the monitor and persist the weights.
        if early_stopper.best_epoch == epoch: 
            early_stopper.best_ap = test_ap
            early_stopper.best_auc = test_auc
            early_stopper.best_acc = test_acc
            early_stopper.best_loss = test_loss
            logger.info(f'Saving the best model at epoch {early_stopper.best_epoch}')
            torch.save(encoder.state_dict(), MODEL_SAVE_PATH+get_model_name('Encoder'))
            torch.save(decoder.state_dict(), MODEL_SAVE_PATH+get_model_name('Decoder'))
예제 #28
0
def train(config):
    """Train the seq2seq encoder/decoder and log train/test loss per epoch.

    Args:
        config: Mapping with the sections read below: ``config['train']``
            (keys ``device``, ``batch_size``, ``lr``, ``n_epochs``),
            ``config['dataset']`` (forwarded to ``create_datasets``), and
            ``config['encoder']`` / ``config['decoder']`` (forwarded as
            keyword arguments to the model constructors).

    Side effects:
        Rebinds the module-level ``device``, writes TensorBoard scalars to
        the ``log`` directory and calls ``sample`` after every epoch.
    """
    train_config = config['train']

    # The chosen device is published as a module-level global.
    global device
    device = train_config['device']
    if not torch.cuda.is_available():
        device = 'cpu'
    tqdm.write('Training on {}'.format(device))
    writer = SummaryWriter('log')

    train_dataset, test_dataset = create_datasets(**config['dataset'])

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=train_config['batch_size'],
                                  shuffle=True,
                                  collate_fn=collate_fn)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=train_config['batch_size'],
                                 shuffle=False,
                                 collate_fn=collate_fn)

    encoder = Encoder(vocab_size=len(train_dataset.lang1),
                      **config['encoder'],
                      device=device).to(device)
    decoder = Decoder(vocab_size=len(train_dataset.lang2),
                      **config['decoder']).to(device)

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=train_config['lr'])
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=train_config['lr'])

    criterion = nn.NLLLoss()

    tqdm.write('[-] Start training! ')
    epoch_bar = tqdm(range(train_config['n_epochs']),
                     desc='[Total progress]',
                     leave=True,
                     position=0,
                     dynamic_ncols=True)
    for epoch in epoch_bar:
        # Iterate the loader itself: calling next(iter(loader)) per step
        # would restart the loader every time and never cover the dataset.
        batch_bar = tqdm(train_dataloader,
                         desc='[Train epoch {:2}]'.format(epoch),
                         leave=True,
                         position=0,
                         dynamic_ncols=True)
        encoder.train()
        decoder.train()
        train_loss = 0
        for source, target_bos, target_eos in batch_bar:
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            source = source.to(device)
            target_bos = target_bos.to(device)
            target_eos = target_eos.to(device)

            encoder_output, encoder_hidden = encoder(source)
            decoder_output = decoder(target_bos, encoder_hidden)

            # Flatten all leading dimensions so NLLLoss scores each position.
            loss = criterion(decoder_output.view(-1, decoder_output.size(-1)),
                             target_eos.view(-1))
            loss_value = loss.item()
            train_loss += loss_value
            n_hit, n_total = hitRate(decoder_output, target_eos)
            loss.backward()

            encoder_optimizer.step()
            decoder_optimizer.step()

            batch_bar.set_description(
                '[Train epoch {:2} | Loss: {:.2f} | Hit: {}/{}]'.format(
                    epoch, loss_value, n_hit, n_total))
        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        train_loss /= max(len(train_dataloader), 1)

        batch_bar = tqdm(test_dataloader,
                         desc='[Test epoch {:2}]'.format(epoch),
                         leave=True,
                         position=0,
                         dynamic_ncols=True)
        encoder.eval()
        decoder.eval()
        test_loss = 0
        with torch.no_grad():
            for source, target_bos, target_eos in batch_bar:
                source = source.to(device)
                target_bos = target_bos.to(device)
                target_eos = target_eos.to(device)

                encoder_output, encoder_hidden = encoder(source)
                decoder_output = decoder(target_bos, encoder_hidden)
                loss = criterion(
                    decoder_output.view(-1, decoder_output.size(-1)),
                    target_eos.view(-1))
                loss_value = loss.item()
                test_loss += loss_value
                n_hit, n_total = hitRate(decoder_output, target_eos)
                batch_bar.set_description(
                    '[Test epoch {:2} | Loss: {:.2f} | Hit: {}/{}]'.format(
                        epoch, loss_value, n_hit, n_total))

        test_loss /= max(len(test_dataloader), 1)
        writer.add_scalars('Loss', {
            'train': train_loss,
            'test': test_loss
        }, epoch)
        sample(test_dataset, encoder, decoder)

    # Flush pending events to disk before returning.
    writer.close()
    tqdm.write('[-] Done!')
예제 #29
0
_MODEL_TYPES = ('no_attention', 'attention', 'transformer')


def _strip_caption_markers(sentence):
    """Return ``sentence`` without leading '<start>' / trailing '<end>' tokens.

    Works on whole tokens. ``str.strip('<start> ')`` treats its argument as a
    set of characters and would also eat real caption letters such as a
    leading 'a' or a trailing 'r'.
    """
    words = sentence.split()
    first, last = 0, len(words)
    while first < last and words[first] == '<start>':
        first += 1
    while last > first and words[last - 1] == '<end>':
        last -= 1
    return ' '.join(words[first:last])


def _load_reference_captions(caption_path):
    """Map each image file name to its list of reference captions.

    Reads a JSON file with an ``images`` list (``id``, ``file_name``) and an
    ``annotations`` list (``image_id``, ``caption``). Images without any
    annotation map to an empty list.
    """
    with open(caption_path, 'r') as file:
        annotations = json.load(file)

    # Group captions by image id in one pass instead of rescanning every
    # annotation for every image.
    captions_by_image = {}
    for cap in annotations['annotations']:
        captions_by_image.setdefault(cap['image_id'], []).append(cap['caption'])

    return {
        image['file_name']: captions_by_image.get(image['id'], [])
        for image in annotations['images']
    }


def _batch_loss(args, images, captions, lengths, encoder, decoder, model,
                criterion):
    """Run the forward pass selected by ``args.model_type``; return the loss.

    ``images`` and ``captions`` must already be on the model's device.
    """
    if args.model_type == 'no_attention':
        features = encoder(images)
        # Packing flattens the captions to their valid (unpadded) tokens.
        targets = pack_padded_sequence(captions, lengths,
                                       batch_first=True)[0]
        outputs = decoder(features, captions, lengths)
        return criterion(outputs, targets)

    if args.model_type == 'attention':
        features = encoder(images)
        # Targets start at the second word; decode lengths shrink by one.
        targets = captions[:, 1:]
        decode_lengths = [length - 1 for length in lengths]
        targets = targets.reshape(targets.shape[0] * targets.shape[1])
        scores, alphas = decoder(features, captions, decode_lengths)
        scores = scores.view(-1, scores.shape[-1])
        return decoder.loss(scores, targets, alphas)

    # transformer: input is the caption without its last word, target is the
    # caption without its first word.
    trg_input = captions[:, :-1]
    trg_mask = create_masks(trg_input)
    scores = model(images, trg_input, trg_mask)
    scores = scores.view(-1, scores.shape[-1])
    targets = captions[:, 1:]
    return criterion(scores,
                     targets.reshape(targets.shape[0] * targets.shape[1]))


def _generate_caption(args, image_tensor, vocab, encoder, decoder, model,
                      device):
    """Greedily decode one image and return the caption as a string.

    The returned sentence starts with '<start>' and stops after the first
    '<end>' token, if one is produced.
    """
    if args.model_type == 'transformer':
        e_outputs = model.encoder(image_tensor)
        max_seq_length = 20
        # Keep the running sequence on the model's device so it can be fed
        # straight back into the decoder.
        sampled_ids = torch.zeros(max_seq_length,
                                  dtype=torch.long,
                                  device=device)
        sampled_ids[0] = vocab.word2idx['<start>']

        for step in range(1, max_seq_length):
            # Mask that hides positions after the current one.
            trg_mask = np.triu(np.ones((1, step, step)), k=1).astype('uint8')
            trg_mask = (torch.from_numpy(trg_mask) == 0).to(device)

            out = model.decoder(sampled_ids[:step].unsqueeze(0), e_outputs,
                                trg_mask)
            out = model.out(out)
            out = F.softmax(out, dim=-1)
            _, ix = out[:, -1].data.topk(1)
            sampled_ids[step] = ix[0][0]

        sampled_ids = sampled_ids.cpu().numpy()
        sampled_caption = []
    else:
        features = encoder(image_tensor)
        if args.model_type == 'attention':
            sampled_ids, _ = decoder.sample(features)
        else:
            sampled_ids = decoder.sample(features)
        sampled_ids = sampled_ids[0].cpu().numpy()
        # These decoders' output is prefixed with the start marker here.
        sampled_caption = ['<start>']

    # Convert word ids to words, stopping at the end-of-sentence marker.
    for word_id in sampled_ids:
        word = vocab.idx2word[word_id]
        sampled_caption.append(word)
        if word == '<end>':
            break
    return ' '.join(sampled_caption)


def main(args):
    """Train an image-captioning model and evaluate it after every epoch.

    Depending on ``args.model_type`` ('no_attention', 'attention' or
    'transformer') the matching model is built, trained for
    ``args.num_epochs`` epochs and checkpointed to ``args.model_path``. After
    each epoch the validation loss is computed, captions are generated for
    every file in ``args.image_dir_val`` and scored with BLEU, CIDEr and
    ROUGE. Losses, predictions and metrics are written below
    ``args.predict_json`` and to TensorBoard.

    Args:
        args: Parsed command-line namespace; the attributes used are the
            ones read in the body (paths, model hyper-parameters, ``mode``,
            ``log_step``, ...).

    Raises:
        ValueError: If ``args.model_type`` is not a supported model type.
    """
    # Fail early instead of printing a hint and crashing later on an
    # undefined model variable.
    if args.model_type not in _MODEL_TYPES:
        raise ValueError(
            'Select model_type from {} (got {!r})'.format(
                ', '.join(_MODEL_TYPES), args.model_type))

    # create a writer
    writer = SummaryWriter('loss_plot_' + args.mode, comment='test')
    # Create model directory
    os.makedirs(args.model_path, exist_ok=True)

    # Image preprocessing, normalization for the pretrained resnet
    transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load vocabulary wrapper
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # point vocab_path at trusted files.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loaders
    data_loader = get_loader(args.image_dir,
                             args.caption_path,
                             vocab,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)

    data_loader_val = get_loader(args.image_dir_val,
                                 args.caption_path_val,
                                 vocab,
                                 transform,
                                 args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers)

    # Build the model; only the objects needed by the chosen type are set.
    encoder = decoder = model = criterion = None
    if args.model_type == 'no_attention':
        encoder = Encoder(args.embed_size).to(device)
        decoder = Decoder(args.embed_size, args.hidden_size, len(vocab),
                          args.num_layers).to(device)
        criterion = nn.CrossEntropyLoss()

    elif args.model_type == 'attention':
        # The attention decoder provides its own loss (decoder.loss).
        encoder = EncoderAtt(encoded_image_size=9).to(device)
        decoder = DecoderAtt(vocab, args.encoder_dim, args.hidden_size,
                             args.attention_dim, args.embed_size,
                             args.dropout_ratio, args.alpha_c).to(device)

    else:  # transformer
        model = Transformer(len(vocab), args.embed_size,
                            args.transformer_layers, 8,
                            args.dropout_ratio).to(device)

        encoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, model.encoder.parameters()),
                                             lr=args.learning_rate_enc)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, model.decoder.parameters()),
                                             lr=args.learning_rate_dec)
        criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx['<pad>'])

    # Non-transformer models: encoder.fine_tune must run before the encoder
    # optimizer is built, because the optimizer only receives parameters
    # that still require gradients.
    if args.model_type != 'transformer':
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=args.learning_rate_dec)
        encoder.fine_tune(args.fine_tune)
        encoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, encoder.parameters()),
                                             lr=args.learning_rate_enc)

    # Reference captions do not change between epochs; load them once.
    gts = _load_reference_captions(args.caption_path_val)

    # initialize lists to store results:
    loss_train = []
    loss_val = []
    loss_val_epoch = []
    loss_train_epoch = []

    bleu_res_list = []
    cider_res_list = []
    rouge_res_list = []

    # total steps for train and validation
    total_step = len(data_loader)
    total_step_val = len(data_loader_val)

    for epoch in tqdm(range(args.num_epochs)):

        loss_val_iter = []
        loss_train_iter = []

        # set model to train mode
        if args.model_type != 'transformer':
            encoder.train()
            decoder.train()
        else:
            model.train()

        for i, (images, captions, lengths) in enumerate(tqdm(data_loader)):
            images = images.to(device)
            captions = captions.to(device)

            loss = _batch_loss(args, images, captions, lengths, encoder,
                               decoder, model, criterion)

            # Clear BOTH optimizers; otherwise encoder gradients would
            # accumulate from step to step.
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
            loss.backward()
            decoder_optimizer.step()
            encoder_optimizer.step()

            # append results to loss lists and writer
            loss_value = loss.item()
            loss_train_iter.append(loss_value)
            loss_train.append(loss_value)
            # Use a run-wide step so epochs do not overwrite each other.
            writer.add_scalar('Loss/train/iterations', loss_value,
                              epoch * total_step + i + 1)

            # Print log info
            if i % args.log_step == 0:
                print(
                    'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                    .format(epoch, args.num_epochs, i, total_step, loss_value,
                            np.exp(loss_value)))

        print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'.
              format(epoch, args.num_epochs, i, total_step, loss.item(),
                     np.exp(loss.item())))

        # mean of the last 10 batches as approximate epoch loss
        epoch_train_loss = np.mean(loss_train_iter[-10:])
        loss_train_epoch.append(epoch_train_loss)
        writer.add_scalar('Loss/train/epoch', epoch_train_loss, epoch + 1)

        # save model
        if args.model_type != 'transformer':
            torch.save(
                decoder.state_dict(),
                os.path.join(
                    args.model_path,
                    'decoder_' + args.mode + '_{}.ckpt'.format(epoch + 1)))
            # The encoder gets its own file so it does not overwrite the
            # decoder checkpoint.
            torch.save(
                encoder.state_dict(),
                os.path.join(
                    args.model_path,
                    'encoder_' + args.mode + '_{}.ckpt'.format(epoch + 1)))
        else:
            torch.save(
                model.state_dict(),
                os.path.join(
                    args.model_path,
                    'model_' + args.mode + '_{}.ckpt'.format(epoch + 1)))
        np.save(
            os.path.join(args.predict_json,
                         'loss_train_temp_' + args.mode + '.npy'), loss_train)

        # validate model: set eval mode
        if args.model_type != 'transformer':
            encoder.eval()
            decoder.eval()
        else:
            model.eval()

        with torch.no_grad():
            for i, (images, captions,
                    lengths) in enumerate(tqdm(data_loader_val)):
                images = images.to(device)
                captions = captions.to(device)

                loss = _batch_loss(args, images, captions, lengths, encoder,
                                   decoder, model, criterion)
                loss_value = loss.item()

                # display results
                if i % args.log_step == 0:
                    print(
                        'Epoch [{}/{}], Step [{}/{}], Validation Loss: {:.4f}, Validation Perplexity: {:5.4f}'
                        .format(epoch, args.num_epochs, i, total_step_val,
                                loss_value, np.exp(loss_value)))

                # append results to loss lists and writer
                loss_val.append(loss_value)
                loss_val_iter.append(loss_value)

                writer.add_scalar('Loss/validation/iterations', loss_value,
                                  epoch * total_step_val + i + 1)

        np.save(
            os.path.join(args.predict_json, 'loss_val_' + args.mode + '.npy'),
            loss_val)

        print(
            'Epoch [{}/{}], Step [{}/{}], Validation Loss: {:.4f}, Validation Perplexity: {:5.4f}'
            .format(epoch, args.num_epochs, i, total_step_val, loss.item(),
                    np.exp(loss.item())))

        # epoch validation loss: mean over THIS epoch's batches
        epoch_val_loss = np.mean(loss_val_iter)
        loss_val_epoch.append(epoch_val_loss)
        writer.add_scalar('Loss/validation/epoch', epoch_val_loss, epoch + 1)

        # predict captions for every validation image
        predicted = {}
        with torch.no_grad():
            for file in tqdm(os.listdir(args.image_dir_val)):
                if file == '.DS_Store':
                    continue
                image = load_image(os.path.join(args.image_dir_val, file),
                                   transform)
                image_tensor = image.to(device)
                predicted[file] = _generate_caption(args, image_tensor, vocab,
                                                    encoder, decoder, model,
                                                    device)

        # save predictions to json file:
        predictions_path = os.path.join(
            args.predict_json,
            'predicted_' + args.mode + '_' + str(epoch) + '.json')
        with open(predictions_path, 'w') as out_file:
            json.dump(predicted, out_file)

        # Candidate captions without the sentence markers.
        res = {
            name: [_strip_caption_markers(sentence)]
            for name, sentence in predicted.items()
        }

        # calculate BLEU, CIDEr and ROUGE from reference and generated captions
        bleu_res = bleu(gts, res)
        cider_res = cider(gts, res)
        rouge_res = rouge(gts, res)

        # append results to result lists
        bleu_res_list.append(bleu_res)
        cider_res_list.append(cider_res)
        rouge_res_list.append(rouge_res)

        # write results to writer
        writer.add_scalar('BLEU1/validation/epoch', bleu_res[0], epoch + 1)
        writer.add_scalar('BLEU2/validation/epoch', bleu_res[1], epoch + 1)
        writer.add_scalar('BLEU3/validation/epoch', bleu_res[2], epoch + 1)
        writer.add_scalar('BLEU4/validation/epoch', bleu_res[3], epoch + 1)
        writer.add_scalar('CIDEr/validation/epoch', cider_res, epoch + 1)
        writer.add_scalar('ROUGE/validation/epoch', rouge_res, epoch + 1)

    results = {
        'bleu': bleu_res_list,
        'cider': cider_res_list,
        'rouge': rouge_res_list,
    }

    results_path = os.path.join(args.predict_json,
                                'results_' + args.mode + '.json')
    with open(results_path, 'w') as out_file:
        json.dump(results, out_file)
    np.save(
        os.path.join(args.predict_json, 'loss_train_' + args.mode + '.npy'),
        loss_train)
    np.save(os.path.join(args.predict_json, 'loss_val_' + args.mode + '.npy'),
            loss_val)

    # Flush pending TensorBoard events to disk.
    writer.close()
def main():
    """Train the rotation ``Encoder`` on MNIST from the command line.

    Parses the command-line options, converts the rotation ranges from
    degrees to radians, and trains the model with Adam on pairs of rotated
    digits produced by ``rotate_tensor``. Every ``--log-interval`` batches the
    training loss is printed and logged to TensorBoard; every
    ``--store-interval`` batches ``evaluate_model`` and ``rotation_test`` are
    run on the training set. The model is saved after every epoch, and the
    collected curves are stored as ``.npy`` files and plotted under
    ``./output_<name>``.
    """
    # Training settings
    list_of_choices = ['forbenius', 'cosine_squared', 'cosine_abs']

    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch  rotation test (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=20,
                        metavar='N',
                        help='number of epochs to train (default: 20)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        metavar='LR',
                        help='learning rate (default: 0.0001)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument(
        '--store-interval',
        type=int,
        default=50,
        metavar='N',
        help='how many batches to wait before storing training loss')
    parser.add_argument(
        '--name',
        type=str,
        default='',
        help='name of the run that is added to the output directory')
    parser.add_argument(
        "--loss",
        dest='loss',
        default='forbenius',
        choices=list_of_choices,
        help=
        'Decide type of loss, (forbenius) norm, difference of (cosine), (default=forbenius)'
    )
    parser.add_argument(
        '--init-rot-range',
        type=float,
        default=360,
        help=
        'Upper bound of range in degrees of initial random rotation of digits, (Default=360)'
    )
    parser.add_argument('--relative-rot-range',
                        type=float,
                        default=90,
                        metavar='theta',
                        help='Relative rotation range (-theta, theta)')
    parser.add_argument('--eval-batch-size',
                        type=int,
                        default=200,
                        metavar='N',
                        help='batch-size for evaluation')

    args = parser.parse_args()

    # Print arguments
    for arg in vars(args):
        sys.stdout.write('{} = {} \n'.format(arg, getattr(args, arg)))
        sys.stdout.flush()

    sys.stdout.write('Random torch seed:{}\n'.format(torch.initial_seed()))
    sys.stdout.flush()

    # The command line takes degrees; the rest of the code gets radians.
    args.init_rot_range = args.init_rot_range * np.pi / 180
    args.relative_rot_range = args.relative_rot_range * np.pi / 180

    # Create save path (exist_ok avoids the check-then-create race)
    path = "./output_" + args.name
    os.makedirs(path, exist_ok=True)

    sys.stdout.write('Start training\n')
    sys.stdout.flush()

    use_cuda = torch.cuda.is_available()

    device = torch.device("cuda" if use_cuda else "cpu")

    writer = SummaryWriter(path, comment='Encoder atan2 MNIST')
    # Set up dataloaders
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data',
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)

    # Second loader over the same training split, used only for evaluation.
    train_loader_eval = torch.utils.data.DataLoader(
        datasets.MNIST('../data',
                       train=True,
                       transform=transforms.Compose([transforms.ToTensor()])),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs)

    # Init model and optimizer
    model = Encoder(device).to(device)

    # Initialise weights
    model.apply(weights_init)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Init losses log
    prediction_mean_error = []  # Average rotation prediction error in degrees
    prediction_error_std = []  # Std of error for rotation prediction
    train_loss = []

    # Train
    n_iter = 0
    for epoch in range(1, args.epochs + 1):
        sys.stdout.write('Epoch {}/{} \n '.format(epoch, args.epochs))
        sys.stdout.flush()

        for batch_idx, (data, targets) in enumerate(train_loader):
            # Re-enter train mode on every batch before the forward pass.
            model.train()
            # The dataset labels are discarded: rotate_tensor returns the
            # input images, the target images and the angles between them.
            data, targets, angles = rotate_tensor(data.numpy(),
                                                  args.init_rot_range,
                                                  args.relative_rot_range)
            data = torch.from_numpy(data).to(device)
            targets = torch.from_numpy(targets).to(device)
            angles = torch.from_numpy(angles).to(device)
            angles = angles.view(angles.size(0), 1)

            # Forward passes
            optimizer.zero_grad()
            f_data = model(data)  # [N,2,1,1]
            f_targets = model(targets)  # [N,2,1,1]

            # Apply rotation matrix to f_data with feature transformer
            f_data_trasformed = feature_transformer(f_data, angles, device)

            # Define loss
            loss = define_loss(args, f_data_trasformed, f_targets)

            # Backprop
            loss.backward()
            optimizer.step()

            # Log progress
            if batch_idx % args.log_interval == 0:
                # Convert once to a Python float for printing and logging.
                loss_value = loss.item()
                sys.stdout.write(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\r'.format(
                        epoch, batch_idx * len(data),
                        len(train_loader.dataset),
                        100. * batch_idx / len(train_loader), loss_value))
                sys.stdout.flush()

                writer.add_scalar('Training Loss', loss_value, n_iter)

            # Store training loss and rotation error
            if batch_idx % args.store_interval == 0:
                # Train loss
                train_loss.append(
                    evaluate_model(args, model, device, train_loader_eval))

                # Rotation error on the training set
                mean, std = rotation_test(args, model, device,
                                          train_loader_eval)
                prediction_mean_error.append(mean)
                writer.add_scalar('Mean test error', mean, n_iter)

                prediction_error_std.append(std)

            n_iter += 1

        # Save model after every epoch
        save_model(args, model)

    # No scalars are logged past this point; flush them to disk.
    writer.close()

    # Save losses
    train_loss = np.array(train_loss)
    prediction_mean_error = np.array(prediction_mean_error)
    prediction_error_std = np.array(prediction_error_std)

    np.save(path + '/training_loss', train_loss)
    np.save(path + '/prediction_mean_error', prediction_mean_error)
    np.save(path + '/prediction_error_std', prediction_error_std)

    plot_learning_curve(args, train_loss, prediction_mean_error,
                        prediction_error_std, path)

    # Get diagnostics per digit
    get_error_per_digit(args, model, device)