Esempio n. 1
0
def train(train_iter, test_iter, net, feature_params, loss, device, num_epochs,
          file_name):
    """Fine-tune ``net`` with Ranger and keep the best-scoring weights.

    Two parameter groups are optimised: ``feature_params`` at the base
    learning rate (0.001) and ``net.fc.parameters()`` at ten times that rate.
    A ``StepLR`` schedule multiplies every group's rate by 0.1 each 10 epochs.

    After every epoch the model is scored with ``evaluate_accuracy`` on
    ``test_iter``.  Whenever the score improves, the state dict is written to
    ``./model/<file_name>/best.pth`` and a line is appended to
    ``./result/<file_name>.txt``.  A checkpoint is also written every 10th
    epoch to ``./model/<file_name>/checkpoint_<epoch>.pth``.

    Args:
        train_iter: iterable yielding ``(X, y)`` training batches.
        test_iter: iterable handed to ``evaluate_accuracy``.
        net: model exposing an ``fc`` sub-module.
        feature_params: parameters for the base-rate group.
        loss: callable ``loss(y_hat, y)`` returning a scalar tensor.
        device: device the model and batches are moved to.
        num_epochs: number of epochs to run.
        file_name: run name used in the output paths above.
    """
    net = net.to(device)
    print("training on ", device)
    best_test_acc = 0
    lr = 0.001
    optimizer = Ranger([{
        'params': feature_params
    }, {
        'params': net.fc.parameters(),
        'lr': lr * 10
    }],
                       lr=lr,
                       weight_decay=0.0001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=10,
                                                gamma=0.1)
    for epoch in range(1, num_epochs + 1):
        train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
        # Batches seen in *this* epoch.  The loss sum is reset every epoch, so
        # it must be averaged over a counter that is reset as well.
        epoch_batches = 0
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
            epoch_batches += 1
        # Step the schedule after the epoch's optimizer updates; stepping
        # first skips the initial value of the learning-rate schedule.
        scheduler.step()
        test_acc = evaluate_accuracy(test_iter, net)
        # max(..., 1) keeps the report from crashing on an empty iterator.
        print(
            'epoch %d, loss %.5f, train_acc %.5f, val_acc %.5f, time %.1f sec'
            % (epoch, train_l_sum / max(epoch_batches, 1),
               train_acc_sum / max(n, 1), test_acc, time.time() - start))
        if test_acc > best_test_acc:
            print('find best! save at model/%s/best.pth' % file_name)
            best_test_acc = test_acc
            torch.save(net.state_dict(), './model/%s/best.pth' % file_name)
            with open('./result/%s.txt' % file_name, 'a') as acc_file:
                acc_file.write('Epoch: %2d, acc: %.8f\n' % (epoch, test_acc))
        if epoch % 10 == 0:
            torch.save(net.state_dict(),
                       './model/%s/checkpoint_%d.pth' % (file_name, epoch))
Esempio n. 2
0
def train_fold():
    """Pretrain a ``NucleicTransformer`` by reconstructing corrupted input.

    Sequences come from ``train.json`` and ``test.json`` under ``opts.path``.
    Test rows with ``seq_length == 130`` form the "long" loader; test rows
    with ``seq_length == 107`` plus all training rows form the "short" loader.
    Each batch is corrupted with either ``mutate_rna_input`` or
    ``mask_rna_input`` (chosen with probability 0.5 each) and the model is
    trained with cross-entropy to predict the three channels of the
    uncorrupted input.

    Every epoch iterates the short loader, then the long loader.  The mean
    loss is logged to ``pretrain.csv``, weights are saved every
    ``opts.save_freq`` epochs, and ``get_best_weights_from_fold`` is called at
    the end.  All configuration is read from ``get_args()``.
    """
    # get arguments
    opts = get_args()

    # gpu selection
    os.environ["CUDA_VISIBLE_DEVICES"] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # instantiate datasets (named train_df rather than `json`, which would
    # shadow the standard-library module name)
    train_df = pd.read_json(os.path.join(opts.path, 'train.json'), lines=True)
    train_ids = train_df.id.to_list()

    test = pd.read_json(os.path.join(opts.path, 'test.json'), lines=True)

    # dataloaders
    long_data = test[test.seq_length == 130]
    long_sequences = long_data.sequence.to_list()
    ids = np.asarray(long_data.id.to_list())
    # Placeholder labels / weights are sized by the filtered rows, matching
    # how the short dataset is built below (the boolean mask itself has one
    # entry per row of the whole test frame).
    long_dataset = RNADataset(long_sequences, np.zeros(len(long_sequences)),
                              ids, np.arange(len(long_sequences)), opts.path)
    long_dataloader = DataLoader(long_dataset,
                                 batch_size=opts.batch_size,
                                 shuffle=True,
                                 num_workers=opts.workers)

    short_data = test[test.seq_length == 107]
    ids = short_data.id.to_list() + train_ids
    short_sequences = (short_data.sequence.to_list() +
                       train_df.sequence.to_list())
    short_dataset = RNADataset(short_sequences, np.zeros(len(short_sequences)),
                               ids, np.arange(len(short_sequences)), opts.path)
    short_dataloader = DataLoader(short_dataset,
                                  batch_size=opts.batch_size,
                                  shuffle=True,
                                  num_workers=opts.workers)

    # checkpointing
    checkpoints_folder = 'pretrain_weights'
    csv_file = 'pretrain.csv'
    columns = ['epoch', 'train_loss']
    logger = CSVLogger(columns, csv_file)

    # build model and optimizer
    model = NucleicTransformer(opts.ntoken,
                               opts.nclass,
                               opts.ninp,
                               opts.nhead,
                               opts.nhid,
                               opts.nlayers,
                               opts.kmer_aggregation,
                               kmers=opts.kmers,
                               stride=opts.stride,
                               dropout=opts.dropout,
                               pretrain=True).to(device)
    optimizer = Ranger(model.parameters(), weight_decay=opts.weight_decay)
    criterion = nn.CrossEntropyLoss()
    model = nn.DataParallel(model)

    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print('Total number of paramters: {}'.format(pytorch_total_params))

    # training loop
    # The cosine schedule is only stepped once epoch > cos_epoch.
    cos_epoch = int(opts.epochs * 0.75)
    total_steps = len(long_dataloader) + len(short_dataloader)
    lr_schedule = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, (opts.epochs - cos_epoch) * (total_steps))
    for epoch in range(opts.epochs):
        model.train(True)
        t = time.time()
        total_loss = 0.0
        optimizer.zero_grad()
        step = 0  # number of batches processed so far in this epoch
        for loader in (short_dataloader, long_dataloader):
            for data in loader:
                step += 1
                lr = get_lr(optimizer)
                src = data['data']
                bpps = data['bpp'].to(device)

                # Corrupt the input: random mutation or masking, 50/50.
                if np.random.uniform() > 0.5:
                    masked = mutate_rna_input(src)
                else:
                    masked = mask_rna_input(src)

                src = src.to(device).long()
                masked = masked.to(device).long()

                output = model(masked, bpps)

                # One head per input channel.  Channels 1 and 2 are shifted
                # by 4 and 7 so their targets start at class 0 (presumably
                # the token ids of the three channels are contiguous --
                # TODO confirm against the dataset encoding).
                loss = (
                    criterion(output[0].reshape(-1, 4),
                              src[:, :, 0].reshape(-1)) +
                    criterion(output[1].reshape(-1, 3),
                              src[:, :, 1].reshape(-1) - 4) +
                    criterion(output[2].reshape(-1, 7),
                              src[:, :, 2].reshape(-1) - 7))

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                optimizer.step()
                optimizer.zero_grad()
                # .item() keeps a plain float; adding the tensor itself would
                # carry autograd history across the whole epoch.
                total_loss += loss.item()
                if epoch > cos_epoch:
                    lr_schedule.step()
                print(
                    "Epoch [{}/{}], Step [{}/{}] Loss: {:.3f} Lr:{:.6f} Time: {:.1f}"
                    .format(epoch + 1, opts.epochs, step, total_steps,
                            total_loss / step, lr,
                            time.time() - t),
                    end='\r',
                    flush=True)
        print('')
        # `step` already equals the batch count; guard against empty loaders.
        train_loss = total_loss / max(step, 1)
        torch.cuda.empty_cache()
        to_log = [
            epoch + 1,
            train_loss,
        ]
        logger.log(to_log)

        if (epoch + 1) % opts.save_freq == 0:
            save_weights(model, optimizer, epoch, checkpoints_folder)

    get_best_weights_from_fold(opts.fold)
Esempio n. 3
0
def train_discriminator_adv(dis, dis_option_encoder, state_encoder, reason_decoder, dis_optimizer,  epochs, max_length = 25):
	"""
	Train the discriminator on options where one wrong option carries a
	generated reason.

	For every batch from the module-level ``train_loader_adv`` a reason is
	greedily decoded with ``state_encoder`` / ``reason_decoder`` (at most
	``max_length`` steps), written into one randomly chosen *incorrect*
	option of each sample, and the discriminator is then trained with
	cross-entropy to pick the labelled option among the three.

	Only ``dis_option_encoder`` and ``dis`` are optimised.  The
	``dis_optimizer`` argument is ignored: it is replaced by a fresh Ranger
	optimizer below (kept as in the original so callers see no change).

	Prints the mean batch loss and the sample accuracy after every epoch.
	"""

	dis_optimizer = Ranger(
		itertools.chain(dis_option_encoder.parameters(), dis.parameters()), lr=1e-3,
		weight_decay=1e-4)
	# Loop invariants: built once instead of once per batch.
	loss_fn = nn.CrossEntropyLoss()
	sos_idx = word_encoder.encode("<sos>")
	eos_idx = word_encoder.encode("<eos>")
	for epoch in tqdm(range(epochs)):
		sys.stdout.flush()
		total_loss = 0
		total_acc = 0
		count = 0
		num_samples = 0
		for train_data in train_loader_adv:
			num, statement, reason, label, options = train_data
			batch_size, _ = statement.size()

			# For each sample pick one of the two options that is NOT the label.
			falselabel = [[0, 1, 2] for _ in label]
			for candidates, true_label in zip(falselabel, label):
				candidates.remove(true_label)
			choice_label = [random.choice(ls) for ls in falselabel]

			input_lengths = [len(x) for x in statement]

			# The decoded tokens are turned into strings, so no gradient can
			# reach the generator from the discriminator loss; skip building
			# the autograd graph for this part.
			with torch.no_grad():
				encoder_outputs, hidden = state_encoder(statement, input_lengths)

				decoder_input = torch.tensor([[sos_idx] for _ in range(batch_size)]).type(
					torch.cuda.LongTensor)

				encoder_outputs = encoder_outputs.permute(1, 0, 2)  # -> (T*B*H)

				decoder_outputs = [[] for _ in range(batch_size)]
				for di in range(max_length):
					output, hidden = reason_decoder(decoder_input, hidden, encoder_outputs)
					topv, topi = output.topk(1)
					# NOTE(review): squeeze() yields a 0-dim tensor when
					# batch_size == 1, which cannot be indexed -- confirm
					# batches are always larger than one.
					tokens = topi.squeeze()
					all_end = True
					for i in range(batch_size):
						token = tokens[i].item()
						if token != eos_idx:
							decoder_outputs[i].append(word_encoder.decode(np.array([token])))
							all_end = False
					if all_end:
						break
					decoder_input = tokens.detach()
			decoder_outputs = [" ".join(output) for output in decoder_outputs]

			# Overwrite the reason text of the chosen wrong option.
			for i in range(batch_size):
				options[choice_label[i]][i][1] = decoder_outputs[i]

			option_input_lengths = [[len(x[0]) + len(x[1]) for x in option] for option in options]
			option_lens = [max(len(option[0])+len(option[1])+1 for option in options[i]) for i in range(3)]
			options = [[word_encoder.encode("<sep>".join([x[0],x[1]])) for x in option] for option in options]
			options = [sequence.pad_sequences(option, maxlen=option_len, padding='post') for option, option_len in zip(options, option_lens) ]
			options = [torch.tensor(option).type(torch.cuda.LongTensor) for option in options]

			option_hiddens = []
			for i in range(3):
				encoder_outputs, option_hidden = dis_option_encoder(options[i], option_input_lengths[i])
				option_hiddens.append(option_hidden)

			dis_optimizer.zero_grad()
			out = dis(batch_size, option_hiddens)
			label = torch.tensor(label).type(torch.cuda.LongTensor)
			loss = loss_fn(out, label)
			loss.backward()
			dis_optimizer.step()
			total_loss += loss.item()
			total_acc += (out.argmax(1) == label).sum().item()
			num_samples += batch_size

			sys.stdout.flush()
			count += 1

		# The loss is already a per-sample mean, so average it over batches;
		# accuracy is averaged over the samples actually seen, which stays
		# correct when the last batch is smaller.
		print('\n average_loss = %.4f, train_acc = %.4f' % (
			total_loss / max(count, 1), total_acc / max(num_samples, 1)))
Esempio n. 4
0
def train_alphaBert(DS_model,
                    dloader,
                    lr=1e-4,
                    epoch=10,
                    log_interval=20,
                    lkahead=False):
    """Train ``DS_model`` through its ``'finehead'`` output with cross-entropy.

    The model is moved to the module-level ``device``, optimised with Ranger
    and wrapped in ``DataParallel``.  Targets equal to ``-1`` are ignored by
    the loss.  A checkpoint is written through ``save_checkpoint`` every 400
    iterations and after every epoch, and the per-epoch mean loss history is
    rewritten to ``./iou_pic/total_loss_finetune.csv`` after each epoch.

    Args:
        DS_model: model called as
            ``DS_model(input_ids=..., attention_mask=..., out='finehead')``.
        dloader: loader yielding dicts with ``'src_token'``, ``'trg'``,
            ``'mask_padding'`` and ``'origin_seq_length'`` entries.
        lr: learning rate passed to Ranger.
        epoch: number of epochs.
        log_interval: print the batch loss every this many iterations.
        lkahead: accepted for compatibility; not used by this function.
    """
    DS_model.to(device)
    model_optimizer = Ranger(DS_model.parameters(), lr=lr)
    DS_model = torch.nn.DataParallel(DS_model)
    DS_model.train()
    criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)
    iteration = 0
    total_loss = []
    for ep in range(epoch):
        DS_model.train()

        t0 = time.time()
        epoch_loss = 0
        epoch_cases = 0
        for batch_idx, sample in enumerate(dloader):
            model_optimizer.zero_grad()

            src = sample['src_token']
            trg = sample['trg']
            att_mask = sample['mask_padding']
            origin_len = sample['origin_seq_length']

            bs = len(src)

            src = src.float().to(device)
            trg = trg.long().to(device)
            att_mask = att_mask.float().to(device)
            origin_len = origin_len.to(device)

            pred_prop, = DS_model(input_ids=src,
                                  attention_mask=att_mask,
                                  out='finehead')

            loss = criterion(pred_prop, trg.view(-1).contiguous())

            loss.backward()
            model_optimizer.step()

            # loss.item() is a plain float; weight it by the batch size so the
            # epoch value is a per-sample mean.
            batch_loss = loss.item()
            epoch_loss += batch_loss * bs
            epoch_cases += bs

            if iteration % log_interval == 0:
                # NOTE(review): `batch_size` is not defined in this function;
                # it is presumably a module-level constant -- confirm.
                print(
                    'Ep:{} [{} ({:.0f}%)/ ep_time:{:.0f}min] L:{:.4f}'.format(
                        ep, batch_idx * batch_size,
                        100. * batch_idx / len(dloader), (time.time() - t0) *
                        len(dloader) / (60 * (batch_idx + 1)), batch_loss))

            if iteration % 400 == 0:
                save_checkpoint(checkpoint_file,
                                'd2s_total.pth',
                                DS_model,
                                model_optimizer,
                                parallel=parallel)
                print(
                    tokenize_alphabets.convert_idx2str(src[0][:origin_len[0]]))
            iteration += 1

        # End-of-epoch checkpoint.
        save_checkpoint(checkpoint_file,
                        'd2s_total.pth',
                        DS_model,
                        model_optimizer,
                        parallel=parallel)
        print('======= epoch:%i ========' % ep)

        print('++ Ep Time: {:.1f} Secs ++'.format(time.time() - t0))
        # An empty loader yields no cases; record NaN instead of dividing by 0.
        if epoch_cases:
            total_loss.append(float(epoch_loss / epoch_cases))
        else:
            total_loss.append(float('nan'))
        pd_total_loss = pd.DataFrame(total_loss)
        pd_total_loss.to_csv('./iou_pic/total_loss_finetune.csv', sep=',')
    print(total_loss)
Esempio n. 5
0
def train(args):
    """Train a ``DATE`` model with classification, regression and FGSM losses.

    Per batch the loss is
    ``BCE(cls) + alpha * MSE(reg) + beta * weighted-BCE(adversarial)``, where
    the adversarial prediction comes from ``fgsm_attack`` applied to the
    hidden vector and the adversarial BCE is weighted by the class labels.

    After every epoch the model is evaluated on the validation and test
    loaders.  The mean of the ``f1s`` returned by ``metrics`` on validation
    selects the best model, which is saved whole to ``model_path``.

    Data loaders, label arrays, model sizes and ``model_path`` are read from
    module-level names.

    Args:
        args: namespace providing ``epoch``, ``dim``, ``lr``, ``l2``,
            ``head_num``, ``device``, ``act``, ``fusion``, ``beta``,
            ``alpha``, ``use_self`` and ``agg``.
    """
    # get configs
    epochs = args.epoch
    dim = args.dim
    lr = args.lr
    weight_decay = args.l2
    head_num = args.head_num
    device = args.device
    act = args.act
    fusion = args.fusion
    beta = args.beta
    alpha = args.alpha
    use_self = args.use_self
    agg = args.agg
    model = DATE(leaf_num, importer_size, item_size,
                 dim, head_num,
                 fusion_type=fusion, act=act, device=device,
                 use_self=use_self, agg_type=agg,
                 ).to(device)
    model = nn.DataParallel(model, device_ids=[0, 1])

    # initialize parameters (Xavier for every tensor with more than one dim)
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    # optimizer & loss
    optimizer = Ranger(model.parameters(), weight_decay=weight_decay, lr=lr)
    cls_loss_func = nn.BCELoss()
    reg_loss_func = nn.MSELoss()

    # best validation score seen so far
    global_best_score = 0

    # early stop settings
    stop_rounds = 3
    no_improvement = 0
    current_score = None

    for epoch in range(epochs):
        for step, (batch_feature, batch_user, batch_item, batch_cls,
                   batch_reg) in enumerate(train_loader):
            # Re-asserted every step, as in the original; model.eval() is
            # called at the end of each epoch.
            model.train()
            batch_feature = batch_feature.to(device)
            batch_user = batch_user.to(device)
            batch_item = batch_item.to(device)
            batch_cls = batch_cls.to(device).view(-1, 1)
            batch_reg = batch_reg.to(device).view(-1, 1)

            # model output
            classification_output, regression_output, hidden_vector = model(
                batch_feature, batch_user, batch_item)

            # FGSM attack on the hidden representation
            adv_vector = fgsm_attack(model, cls_loss_func, hidden_vector,
                                     batch_cls, 0.01)
            adv_output = model.module.pred_from_hidden(adv_vector)

            # calculate loss
            adv_loss_func = nn.BCELoss(weight=batch_cls)
            adv_loss = beta * adv_loss_func(adv_output, batch_cls)
            cls_loss = cls_loss_func(classification_output, batch_cls)
            revenue_loss = alpha * reg_loss_func(regression_output, batch_reg)
            loss = cls_loss + revenue_loss + adv_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (step + 1) % 1000 == 0:
                print("CLS loss:%.4f, REG loss:%.4f, ADV loss:%.4f, Loss:%.4f"
                      % (cls_loss.item(), revenue_loss.item(),
                         adv_loss.item(), loss.item()))

        # evaluate
        model.eval()
        print("Validate at epoch %s" % (epoch + 1))
        y_prob, _ = model.module.eval_on_batch(valid_loader)
        best_threshold, _, _ = torch_threshold(y_prob, xgb_validy)
        overall_f1, auc, precisions, recalls, f1s, revenues = metrics(
            y_prob, xgb_validy, revenue_valid)
        select_best = np.mean(f1s)
        print("Over-all F1:%.4f, AUC:%.4f, F1-top:%.4f" % (overall_f1, auc, select_best) )

        print("Evaluate at epoch %s" % (epoch + 1))
        y_prob, _ = model.module.eval_on_batch(test_loader)
        overall_f1, auc, precisions, recalls, f1s, revenues = metrics(
            y_prob, xgb_testy, revenue_test, best_thresh=best_threshold)
        print("Over-all F1:%.4f, AUC:%.4f, F1-top:%.4f" %(overall_f1, auc, np.mean(f1s)) )

        # save best model
        if select_best > global_best_score:
            global_best_score = select_best
            torch.save(model, model_path)

        # early stopping
        # With no baseline, this epoch's score becomes the baseline and the
        # remaining checks are skipped.
        if current_score is None:
            current_score = select_best
            continue
        # A drop lowers the baseline and counts as one round without gain.
        if select_best < current_score:
            current_score = select_best
            no_improvement += 1
        if no_improvement >= stop_rounds:
            print("Early stopping...")
            break
        # A gain resets the counter and clears the baseline, so the next
        # epoch only re-establishes it.
        if select_best > current_score:
            no_improvement = 0
            current_score = None
Esempio n. 6
0
def train(train_data, val_data, fold_idx=None):
    """Train ``Net`` on one split and keep the best validation weights.

    The model is optimised with Ranger and focal loss; a cosine-annealing
    schedule is stepped once per epoch.  Whenever the validation score is at
    least the best seen so far, the state dict is saved to
    ``config.model_path``.  When no improvement has happened for
    ``config.patience_epoch`` epochs, or the best score has been tied three
    times, a "adjust lr" round is counted; once ``config.adjust_lr_num``
    rounds are used up, training stops.

    Args:
        train_data: raw training data wrapped in ``MyDataset``.
        val_data: raw validation data wrapped in ``MyDataset``.
        fold_idx: optional fold number; when given it is appended to the
            saved file name and the best score is stored in
            ``model_score[fold_idx]``.
    """
    train_loader = DataLoader(MyDataset(train_data, train_transform),
                              batch_size=config.batch_size,
                              shuffle=True)
    val_loader = DataLoader(MyDataset(val_data, val_transform),
                            batch_size=config.batch_size,
                            shuffle=False)

    model = Net(model_name).to(device)
    criterion = FocalLoss(0.5)
    optimizer = Ranger(model.parameters(), lr=1e-3, weight_decay=0.0005)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=4)

    if fold_idx is not None:
        print('start fold: {}'.format(fold_idx + 1))
        weights_file = '{}_fold{}.bin'.format(model_name, fold_idx)
    else:
        print('start')
        weights_file = '{}.bin'.format(model_name)
    model_save_path = os.path.join(config.model_path, weights_file)

    best_val_score = 0
    tie_count = 0          # times the best score was matched exactly
    last_improved_epoch = 0
    lr_rounds_used = 0
    for cur_epoch in range(config.epochs_num):
        epoch_started = int(time.time())
        model.train()
        print('epoch:{}, step:{}'.format(cur_epoch + 1, len(train_loader)))
        for cur_step, (batch_x, batch_y) in enumerate(train_loader, start=1):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            optimizer.zero_grad()
            probs = model(batch_x)
            train_loss = criterion(probs, batch_y)
            train_loss.backward()
            optimizer.step()

            if cur_step % config.train_print_step != 0:
                continue
            train_acc = accuracy(probs, batch_y)
            step_msg = 'the current step: {0}/{1}, train loss: {2:>5.2}, train acc: {3:>6.2%}'
            print(step_msg.format(cur_step, len(train_loader),
                                  train_loss.item(), train_acc[0].item()))

        val_loss, val_score = evaluate(model, val_loader, criterion)
        improved = val_score >= best_val_score
        if improved:
            if val_score == best_val_score:
                tie_count += 1
            best_val_score = val_score
            torch.save(model.state_dict(), model_save_path)
            last_improved_epoch = cur_epoch
        improved_str = '*' if improved else ''

        epoch_msg = 'the current epoch: {0}/{1}, val loss: {2:>5.2}, val acc: {3:>6.2%}, cost: {4}s {5}'
        epoch_finished = int(time.time())
        print(epoch_msg.format(cur_epoch + 1, config.epochs_num, val_loss,
                               val_score, epoch_finished - epoch_started,
                               improved_str))

        out_of_patience = (cur_epoch - last_improved_epoch
                           >= config.patience_epoch)
        if out_of_patience or tie_count >= 3:
            if lr_rounds_used >= config.adjust_lr_num:
                print("No optimization for a long time, auto stopping...")
                break
            print("No optimization for a long time, adjust lr...")
            # Restart the patience window; otherwise this branch would fire
            # again on every following epoch.
            last_improved_epoch = cur_epoch
            lr_rounds_used += 1
            tie_count = 0
        scheduler.step()

    del model
    gc.collect()

    if fold_idx is not None:
        model_score[fold_idx] = best_val_score
Esempio n. 7
0
def train(train_data, val_data, fold_idx=None):
    """Train the span-extraction ``Model`` and keep the best validation weights.

    The model predicts start and end logits from ``ids``, ``attention_mask``
    and ``token_type_ids`` and is optimised with Ranger on ``loss_fn``.  A
    cosine-annealing schedule is stepped every ``config.train_print_step``
    training steps.  Whenever the validation score is at least the best seen
    so far, the state dict is saved under ``config.model_path``.  After
    ``config.patience_epoch`` epochs without improvement a round is counted;
    once ``model_config.adjust_lr_num`` rounds are used up, training stops.

    Args:
        train_data: raw training data wrapped in ``MyDataset``.
        val_data: raw validation data wrapped in ``MyDataset``.
        fold_idx: optional fold number; when given it is appended to the
            saved file name and the best score is stored in
            ``model_score[fold_idx]``.
    """
    train_dataset = MyDataset(train_data, tokenizer)
    val_dataset = MyDataset(val_data, tokenizer)

    # NOTE(review): the training loader does not shuffle, so batches arrive
    # in the same order every epoch -- confirm this is intended.
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size)

    model = Model().to(config.device)
    optimizer = Ranger(model.parameters(), lr=5e-5)
    # One cosine half-period spans the scheduler steps taken in one epoch.
    period = len(train_loader) // config.train_print_step
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=period,
                                                           eta_min=5e-9)

    if fold_idx is None:
        print('start')
        model_save_path = os.path.join(config.model_path,
                                       '{}.bin'.format(model_name))
    else:
        print('start fold: {}'.format(fold_idx + 1))
        model_save_path = os.path.join(
            config.model_path, '{}_fold{}.bin'.format(model_name, fold_idx))

    best_val_score = 0
    last_improved_epoch = 0
    adjust_lr_num = 0
    for cur_epoch in range(config.epochs_num):
        start_time = int(time.time())
        model.train()
        print('epoch:{}, step:{}'.format(cur_epoch + 1, len(train_loader)))
        cur_step = 0
        for inputs in train_loader:
            # Move ids, masks, and targets to the device as torch.long
            ids = inputs["ids"].to(config.device, dtype=torch.long)
            token_type_ids = inputs["token_type_ids"].to(config.device,
                                                         dtype=torch.long)
            attention_mask = inputs["attention_mask"].to(config.device,
                                                         dtype=torch.long)
            targets_start_idx = inputs["start_idx"].to(config.device,
                                                       dtype=torch.long)
            targets_end_idx = inputs["end_idx"].to(config.device,
                                                   dtype=torch.long)
            optimizer.zero_grad()
            start_logits, end_logits = model(
                input_ids=ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
            loss = loss_fn(start_logits, end_logits, targets_start_idx,
                           targets_end_idx)
            loss.backward()
            optimizer.step()
            cur_step += 1
            if cur_step % config.train_print_step == 0:
                scheduler.step()
                msg = 'the current step: {0}/{1}, cost: {2}s'
                print(
                    msg.format(cur_step, len(train_loader),
                               int(time.time()) - start_time))
        val_loss, val_score = evaluate(model, val_loader)
        if val_score >= best_val_score:
            best_val_score = val_score
            torch.save(model.state_dict(), model_save_path)
            improved_str = '*'
            last_improved_epoch = cur_epoch
        else:
            improved_str = ''
        msg = 'the current epoch: {0}/{1}, val loss: {2:>5.2}, val acc: {3:>6.2%}, cost: {4}s {5}'
        end_time = int(time.time())
        print(
            msg.format(cur_epoch + 1, config.epochs_num, val_loss, val_score,
                       end_time - start_time, improved_str))
        if cur_epoch - last_improved_epoch >= config.patience_epoch:
            # NOTE(review): the limit is read from `model_config` while every
            # other setting here comes from `config` -- confirm.
            if adjust_lr_num >= model_config.adjust_lr_num:
                print("No optimization for a long time, auto stopping...")
                break
            print("No optimization for a long time, adjust lr...")
            # Restart the patience window; otherwise this branch would fire
            # again on every following epoch.
            last_improved_epoch = cur_epoch
            adjust_lr_num += 1

    del model
    gc.collect()

    if fold_idx is not None:
        model_score[fold_idx] = best_val_score
Esempio n. 8
0
def main(args):
    """Train ``CNN`` on the mango data, validate periodically and plot curves.

    Args:
        args: Namespace-like object. The attributes read here are
            ``save_dir``, ``random_seed``, ``train_batch``, ``test_batch``,
            ``workers``, ``lr``, ``epoch`` and ``val_epoch``.

    Side effects:
        Writes checkpoints through ``save_model`` into ``args.save_dir``
        (one per epoch, plus the best-accuracy and best-recall models) and
        shows three matplotlib figures once training has finished.
    """
    # Directory that receives every checkpoint written below.
    os.makedirs(args.save_dir, exist_ok=True)

    # Seed every RNG used by this script.
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    best_acc = 0
    best_recall = 0
    # Last validation metrics; initialised so the per-epoch checkpoint name
    # can always be formatted, even before the first validation pass.
    acc = 0.0
    w_recall = 0.0
    train_loss = []
    total_wrecall = [0]
    total_acc = [0]
    # The evaluation below works on a 3x3 confusion matrix.
    # NOTE(review): assumes the targets are the class indices 0, 1, 2 (as
    # nn.CrossEntropyLoss requires for three classes) -- confirm against
    # MyDataset.
    class_ids = [0, 1, 2]

    print('===> loading data')
    train_set = MyDataset(r'E:\ACV\MangoClassify', 'C1-P1_Train', 'train.csv',
                          'train')
    test_set = MyDataset(r'E:\ACV\MangoClassify', 'C1-P1_Dev', 'dev.csv',
                         'test')
    print('===> build dataloader ...')

    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=args.train_batch,
                                               num_workers=args.workers,
                                               shuffle=False)
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                              batch_size=args.test_batch,
                                              num_workers=args.workers,
                                              shuffle=False)

    print('===> prepare model ...')
    model = CNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = Ranger(model.parameters(), lr=args.lr)

    print('===> start training ...')

    for epoch in range(1, args.epoch + 1):
        # Validation switches the model to eval mode, so training mode has
        # to be restored at the start of every epoch.
        model.train()
        for idx, (imgs, label) in enumerate(train_loader):
            train_info = 'Epoch: [{0}][{1}/{2}]'.format(
                epoch, idx + 1, len(train_loader))

            # Batches arrive channel-last; the model is fed channel-first.
            imgs = imgs.permute(0, 3, 1, 2).to(device, dtype=torch.float)
            label = label.to(device)

            pred = model(imgs)
            loss = criterion(pred, label)
            train_loss += [loss.item()]
            torch.cuda.empty_cache()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_info += ' loss: {:.8f}'.format(loss.item())
            print(train_info)

        if epoch % args.val_epoch == 0:
            test_info = 'Epoch: [{}] '.format(epoch)
            model.eval()

            total = 0
            # Rows are true classes, columns are predicted classes.
            cm_total = np.zeros((len(class_ids), len(class_ids)),
                                dtype=np.int64)
            # No gradients are needed while scoring the dev set.
            with torch.no_grad():
                for imgs, label in test_loader:
                    imgs = imgs.permute(0, 3, 1, 2).to(device,
                                                       dtype=torch.float)
                    pred = model(imgs)
                    # Passing ``labels`` keeps the matrix 3x3 even when a
                    # batch does not contain every class.
                    cm_total += confusion_matrix(
                        label.cpu().numpy(),
                        pred.argmax(-1).cpu().numpy(),
                        labels=class_ids)
                    total += len(label)

            hits = np.diag(cm_total)
            support = cm_total.sum(axis=1)
            acc = float(hits.sum()) / total
            # Unweighted mean of the per-class recalls over all three
            # classes (the class-C term used to be replaced by class B).
            w_recall = float(np.mean(hits / support))
            total_wrecall += [w_recall]
            total_acc += [acc]
            test_info += 'Acc:{:.8f} '.format(acc)
            test_info += 'Recall:{:.8f} '.format(w_recall)

            print(test_info)

            if w_recall > best_recall:
                best_recall = w_recall
                save_model(model,
                           os.path.join(args.save_dir, 'model_best_recall.h5'))

            if acc > best_acc:
                best_acc = acc
                save_model(model,
                           os.path.join(args.save_dir, 'model_best_acc.h5'))

        # Per-epoch checkpoint, tagged with the latest validation metrics.
        save_model(
            model,
            os.path.join(
                args.save_dir, 'model_{}_acc={:8f}_recall={:8f}.h5'.format(
                    epoch, acc, w_recall)))

    plt.figure()
    plt.plot(range(1, len(train_loss) + 1), train_loss, '-')
    plt.xlabel("iteration")
    plt.ylabel("loss")
    plt.title("training loss")
    plt.show()

    plt.figure()
    plt.plot(range(1, len(total_acc) + 1), total_acc, '-')
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title("Best Accuracy:" + str(best_acc))
    plt.show()

    plt.figure()
    plt.plot(range(1, len(total_wrecall) + 1), total_wrecall, '-')
    plt.xlabel("Epoch")
    plt.ylabel("Recall")
    plt.title("Best Recall:" + str(best_recall))
    plt.show()
Esempio n. 9
0
def train_model(dataset=dataset,
                save_dir=save_dir,
                num_classes=num_classes,
                lr=lr,
                num_epochs=nEpochs,
                save_epoch=snapshot,
                useTest=useTest,
                test_interval=nTestInterval):
    """Train and validate the model selected by the global ``modelName``.

    Args:
        dataset: Dataset identifier forwarded to ``VideoDataset``.
        save_dir: Directory for TensorBoard logs and checkpoints.
        num_classes (int): Number of classes in the data.
        lr: Learning rate handed to the ``Ranger`` optimizer.
        num_epochs (int, optional): Number of epochs to train for.
        save_epoch: A checkpoint is written every ``save_epoch`` epochs.
        useTest: Whether to also score the test split.
        test_interval: The test split is scored every ``test_interval``
            epochs when ``useTest`` is true.

    Raises:
        NotImplementedError: If ``modelName`` is neither ``'C3D'`` nor
            ``'Res3D'``.
    """
    if modelName == 'C3D':
        model = C3D(num_class=num_classes)
        model.my_load_pretrained_weights('saved_model/c3d.pickle')
        train_params = model.parameters()
    elif modelName == 'Res3D':
        model = generate_model(50)
        model = load_pretrained_model(model,
                                      './saved_model/r3d50_K_200ep.pth',
                                      n_finetune_classes=num_classes)
        train_params = model.parameters()
    else:
        print('We only implemented C3D and Res3D models.')
        raise NotImplementedError

    # Standard cross-entropy loss for classification.
    criterion = nn.CrossEntropyLoss()
    optimizer = Ranger(train_params,
                       lr=lr,
                       betas=(.95, 0.999),
                       weight_decay=5e-4)
    print('use ranger')

    scheduler = CosineAnnealingLR(optimizer,
                                  T_max=32,
                                  eta_min=0,
                                  last_epoch=-1)

    if resume_epoch == 0:
        print("Training {} from scratch...".format(modelName))
    else:
        # NOTE(review): checkpoints are read from save_dir/models here but
        # written directly into save_dir further down, and the scheduler
        # state is not restored -- confirm this is intended.
        checkpoint_path = os.path.join(
            save_dir, 'models',
            saveName + '_epoch-' + str(resume_epoch - 1) + '.pth.tar')
        # Load all tensors onto the CPU.
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
        print("Initializing weights from: {}...".format(checkpoint_path))
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['opt_dict'])

    print('Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))
    if torch.cuda.is_available():
        model = model.cuda()
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        model = nn.DataParallel(model)
        criterion.cuda()

    log_dir = os.path.join(save_dir)

    print('Training model on {} dataset...'.format(dataset))
    train_dataloader = DataLoader(VideoDataset(dataset=dataset,
                                               split='train',
                                               clip_len=16),
                                  batch_size=8,
                                  shuffle=True,
                                  num_workers=8)
    val_dataloader = DataLoader(VideoDataset(dataset=dataset,
                                             split='validation',
                                             clip_len=16),
                                batch_size=8,
                                num_workers=8)
    test_dataloader = DataLoader(VideoDataset(dataset=dataset,
                                              split='test',
                                              clip_len=16),
                                 batch_size=8,
                                 num_workers=8)

    trainval_loaders = {'train': train_dataloader, 'val': val_dataloader}
    trainval_sizes = {
        x: len(trainval_loaders[x].dataset)
        for x in ['train', 'val']
    }
    test_size = len(test_dataloader.dataset)

    # The context manager and the try/finally make sure the text log and
    # the TensorBoard writer are closed even when training is aborted.
    with open('run/log.txt', 'w') as log_file:
        writer = SummaryWriter(log_dir=log_dir)
        try:
            for epoch in range(resume_epoch, num_epochs):
                # Each epoch has a training and a validation phase.
                for phase in ['train', 'val']:
                    start_time = timeit.default_timer()
                    is_train = phase == 'train'

                    running_loss = 0.0
                    running_corrects = 0.0
                    # train()/eval() primarily affects layers such as
                    # BatchNorm or Dropout.
                    if is_train:
                        model.train()
                    else:
                        model.eval()

                    for inputs, labels in tqdm(trainval_loaders[phase]):
                        inputs = inputs.to(device)
                        labels = labels.to(device)

                        # Autograd graphs are only built while training.
                        with torch.set_grad_enabled(is_train):
                            outputs = model(inputs)
                            loss = criterion(outputs, labels)

                        probs = nn.Softmax(dim=1)(outputs)
                        # Index of the most probable class per sample.
                        preds = torch.max(probs, 1)[1]

                        if is_train:
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()

                        running_loss += loss.item() * inputs.size(0)
                        running_corrects += torch.sum(preds == labels.data)
                        print('\ntemp/label:{}/{}'.format(
                            preds[0], labels[0]))

                    if is_train:
                        # CosineAnnealingLR.step() takes an optional epoch
                        # index, not a loss value; advance it once per
                        # epoch, after the optimizer updates.
                        scheduler.step()

                    epoch_loss = running_loss / trainval_sizes[phase]
                    epoch_acc = running_corrects.double(
                    ) / trainval_sizes[phase]

                    if is_train:
                        writer.add_scalar('data/train_loss_epoch',
                                          epoch_loss, epoch)
                        writer.add_scalar('data/train_acc_epoch', epoch_acc,
                                          epoch)
                    else:
                        writer.add_scalar('data/val_loss_epoch', epoch_loss,
                                          epoch)
                        writer.add_scalar('data/val_acc_epoch', epoch_acc,
                                          epoch)

                    # Report against the num_epochs argument rather than
                    # the module-level default it may have overridden.
                    print("[{}] Epoch: {}/{} Loss: {} Acc: {}".format(
                        phase, epoch + 1, num_epochs, epoch_loss, epoch_acc))
                    stop_time = timeit.default_timer()
                    print("Execution time: " + str(stop_time - start_time) +
                          "\n")
                    log_file.write(
                        "\n[{}] Epoch: {}/{} Loss: {} Acc: {}".format(
                            phase, epoch + 1, num_epochs, epoch_loss,
                            epoch_acc))

                if epoch % save_epoch == (save_epoch - 1):
                    snapshot_path = os.path.join(
                        save_dir,
                        saveName + '_epoch-' + str(epoch) + '.pth.tar')
                    torch.save(
                        {
                            'epoch': epoch + 1,
                            'state_dict': model.state_dict(),
                            'opt_dict': optimizer.state_dict(),
                        }, snapshot_path)
                    print("Save model at {}\n".format(snapshot_path))

                if useTest and epoch % test_interval == (test_interval - 1):
                    start_time = timeit.default_timer()
                    epoch_loss, epoch_acc = _evaluate_test_split(
                        model, criterion, test_dataloader, test_size)

                    writer.add_scalar('data/test_loss_epoch', epoch_loss,
                                      epoch)
                    writer.add_scalar('data/test_acc_epoch', epoch_acc,
                                      epoch)

                    print("[test] Epoch: {}/{} Loss: {} Acc: {}".format(
                        epoch + 1, num_epochs, epoch_loss, epoch_acc))
                    stop_time = timeit.default_timer()
                    print("Execution time: " + str(stop_time - start_time) +
                          "\n")
                    log_file.write(
                        "\n[test] Epoch: {}/{} Loss: {} Acc: {}\n".format(
                            epoch + 1, num_epochs, epoch_loss, epoch_acc))
        finally:
            writer.close()


def _evaluate_test_split(model, criterion, test_dataloader, test_size):
    """Return ``(mean loss, accuracy)`` of ``model`` over the test loader."""
    model.eval()
    running_loss = 0.0
    running_corrects = 0.0

    for inputs, labels in tqdm(test_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            outputs = model(inputs)
        probs = nn.Softmax(dim=1)(outputs)
        preds = torch.max(probs, 1)[1]
        loss = criterion(outputs, labels)

        # Weight the batch loss by the batch size so the final division
        # by the dataset size gives a per-sample mean.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    return running_loss / test_size, running_corrects.double() / test_size
Esempio n. 10
0
class TOMTrainer:
    """Adversarial trainer pairing a generator with a discriminator.

    The generator output is split into a rendered person and a composition
    mask; the try-on image blends the cloth and the rendered person through
    that mask. The generator is trained with L1, VGG, mask-L1 and
    adversarial terms, the discriminator with binary cross-entropy.
    """

    def __init__(self,
                 gen,
                 dis,
                 dataloader_train,
                 dataloader_val,
                 gpu_id,
                 log_freq,
                 save_dir,
                 n_step,
                 optimizer='adam'):
        """Store the models, data loaders, losses and optimizers.

        Args:
            gen: Generator module.
            dis: Discriminator module.
            dataloader_train: Loader iterated by ``train``.
            dataloader_val: Loader iterated by ``val``.
            gpu_id: CUDA device index, used when CUDA is available.
            log_freq: Training visuals are logged every ``log_freq`` batches.
            save_dir: Root directory for the logged visuals.
            n_step: Stored on the instance; not read inside this class.
            optimizer: ``'adam'`` or ``'ranger'``. Any other value leaves
                ``optim_g`` / ``optim_d`` unset, so only ``val`` is usable.
        """
        if torch.cuda.is_available():
            self.device = torch.device('cuda:' + str(gpu_id))
        else:
            self.device = torch.device('cpu')
        self.gen = gen.to(self.device)
        self.dis = dis.to(self.device)

        self.dataloader_train = dataloader_train
        self.dataloader_val = dataloader_val

        if optimizer == 'adam':
            self.optim_g = torch.optim.Adam(gen.parameters(),
                                            lr=1e-4,
                                            betas=(0.5, 0.999))
            self.optim_d = torch.optim.Adam(dis.parameters(),
                                            lr=1e-4,
                                            betas=(0.5, 0.999))
        elif optimizer == 'ranger':
            self.optim_g = Ranger(gen.parameters())
            self.optim_d = Ranger(dis.parameters())

        self.criterionL1 = nn.L1Loss()
        self.criterionVGG = VGGLoss()
        self.criterionAdv = torch.nn.BCELoss()
        self.log_freq = log_freq
        self.save_dir = save_dir
        self.n_step = n_step
        self.step = 0
        print('Generator Parameters:',
              sum([p.nelement() for p in self.gen.parameters()]))
        print('Discriminator Parameters:',
              sum([p.nelement() for p in self.dis.parameters()]))

    def train(self, epoch):
        """Iterate 1 epoch over train data and return loss
        """
        return self.iteration(epoch, self.dataloader_train)

    def val(self, epoch):
        """Iterate 1 epoch over validation data and return loss
        """
        return self.iteration(epoch, self.dataloader_val, train=False)

    @staticmethod
    def _batch_ssim(generated, target):
        """Return the mean per-image SSIM between two image batches.

        ``measure.compare_ssim`` works on channel-last arrays rather than
        on tensors, so both batches are detached, copied to the CPU and
        permuted from (batch, channel, height, width) first.
        NOTE(review): per-image averaging is the presumed intent of the
        former whole-batch call -- confirm.
        """
        generated_np = generated.detach().cpu().permute(0, 2, 3, 1).numpy()
        target_np = target.detach().cpu().permute(0, 2, 3, 1).numpy()
        scores = [
            measure.compare_ssim(fake_img, real_img, multichannel=True)
            for fake_img, real_img in zip(generated_np, target_np)
        ]
        return float(sum(scores)) / len(scores)

    def iteration(self, epoch, data_loader, train=True):
        """Run one pass over ``data_loader`` and return the mean batch loss.

        Args:
            epoch: Epoch number, used for the progress bar and logging.
            data_loader: Loader yielding dicts of tensors; entries whose key
                contains ``'name'`` are skipped.
            train: When true both networks are updated; otherwise the
                losses are only evaluated, without building autograd graphs.

        Returns:
            Sum of generator and discriminator losses over all batches,
            divided by the number of batches.
        """
        data_iter = tqdm(enumerate(data_loader),
                         desc='epoch: %d' % (epoch),
                         total=len(data_loader),
                         bar_format='{l_bar}{r_bar}')

        total_loss = 0.0

        for i, _data in data_iter:
            # Move every tensor entry onto the training device.
            data = {
                key: value.to(self.device)
                for key, value in _data.items() if 'name' not in key
            }
            cloth = data['cloth']
            cloth_mask = data['cloth_mask']
            person = data['person']
            batch_size = person.shape[0]

            # Adversarial ground truths, created on the same device as the
            # discriminator output (which may not be the default GPU).
            real = torch.ones(batch_size, 1, device=self.device)
            fake = torch.zeros(batch_size, 1, device=self.device)

            # Gradients are tracked only when the networks get updated.
            with torch.set_grad_enabled(train):
                outputs = self.gen(torch.cat(
                    [data['feature'], cloth],
                    1))  # (batch, channel, height, width)
                rendered_person, composition_mask = torch.split(
                    outputs, 3, 1)
                rendered_person = torch.tanh(rendered_person)
                composition_mask = torch.sigmoid(composition_mask)
                tryon_person = cloth * composition_mask + rendered_person * (
                    1 - composition_mask)

                # Generator losses: reconstruction terms plus how well the
                # try-on image fools the discriminator.
                l_l1 = self.criterionL1(tryon_person, person)
                l_mask = self.criterionL1(composition_mask, cloth_mask)
                l_vgg = self.criterionVGG(tryon_person, person)
                dis_fake = self.dis(
                    torch.cat([data['feature'], cloth, tryon_person], 1))
                l_adv = self.criterionAdv(dis_fake, real)
                loss_g = l_l1 + l_vgg + l_mask + l_adv / batch_size

                # Discriminator loss; the generated image is detached so
                # this term does not reach the generator.
                dis_real_in = torch.cat([data['feature'], cloth, person], 1)
                dis_fake_in = torch.cat(
                    [data['feature'], cloth, tryon_person], 1).detach()
                loss_d = (self.criterionAdv(self.dis(dis_real_in), real) +
                          self.criterionAdv(self.dis(dis_fake_in), fake)) / 2

            visuals = [[data['head'], data['shape'], data['pose']],
                       [cloth, cloth_mask * 2 - 1, composition_mask * 2 - 1],
                       [rendered_person, tryon_person, person]]
            metric = self._batch_ssim(tryon_person, person)

            if train:
                self.optim_g.zero_grad()
                loss_g.backward()
                self.optim_g.step()
                # zero_grad also discards what loss_g.backward() left on
                # the discriminator parameters.
                self.optim_d.zero_grad()
                loss_d.backward()
                self.optim_d.step()
                self.step += 1

            total_loss = total_loss + loss_g.item() + loss_d.item()
            post_fix = {
                'epoch': epoch,
                'iter': i,
                'avg_loss': total_loss / (i + 1),
                'loss_recon': l_l1.item() + l_vgg.item() + l_mask.item(),
                'loss_g': l_adv.item(),
                'loss_d': loss_d.item(),
                'ssim': metric
            }
            if train and i % self.log_freq == 0:
                data_iter.write(str(post_fix))
                board_add_images(visuals, epoch, i,
                                 os.path.join(self.save_dir, 'train'))

        return total_loss / len(data_iter)
Esempio n. 11
0
def train(model, exp_id):
    """Fine-tune ``model`` and record checkpoints plus a CSV training log.

    Args:
        model: Network to train; moved to the device from ``get_device()``.
        exp_id: Experiment identifier. Checkpoints go to
            ``experiment/<exp_id>/save_models`` and the result table to
            ``experiment/<exp_id>/train_result.csv``.

    Raises:
        KeyboardInterrupt: Re-raised after an interrupt checkpoint and the
            partial result table have been written.
    """
    from collections import deque

    save_model_dir = f'experiment/{exp_id}/save_models'
    # Creates missing parent directories and tolerates an existing one.
    os.makedirs(save_model_dir, exist_ok=True)

    # ==== INIT MODEL=================
    device = get_device()

    model.to(device)
    optimizer = Ranger(model.parameters(),
                       lr=cfg["model_params"]["lr"],
                       weight_decay=0.0001)

    train_dataloader, gt_dict = load_tune_data()

    start = time.time()

    train_result = {
        'epoch': [],
        'iter': [],
        'loss[-k:](avg)': [],
        'validation_loss': [],
        'time(minute)': [],
    }

    def append_train_result(epoch, iteration, avg_lossk, validation_loss,
                            run_time):
        """Append one row to the result table (-1 marks an unset column)."""
        train_result['epoch'].append(epoch)
        train_result['iter'].append(iteration)
        train_result['loss[-k:](avg)'].append(avg_lossk)
        train_result['validation_loss'].append(validation_loss)
        train_result['time(minute)'].append(run_time)

    k = 1000
    # Sliding window over the k most recent losses; the deque evicts the
    # oldest entry by itself once it is full.
    lossk = deque(maxlen=k)
    num_iter = len(train_dataloader)
    print(num_iter)
    for epoch in range(cfg['train_params']['epoch']):
        model.train()
        torch.set_grad_enabled(True)

        tr_it = iter(train_dataloader)
        train_progress_bar = tqdm(range(num_iter))
        print('epoch:', epoch)
        for i in train_progress_bar:
            try:
                data = next(tr_it)
                preds, confidences = forward(data, model, device)

                # convert to world positions
                world_from_agents = data['world_from_agent'].float().to(device)
                centroids = data['centroid'].float().to(device)
                # per the original note: bs * 1 * 3 * 3
                world_from_agents = world_from_agents.unsqueeze(1)
                # repeated along dim 1 -> bs * 3 * 3 * 3
                world_from_agents = world_from_agents.repeat(1, 3, 1, 1)
                centroids = centroids.unsqueeze(1).unsqueeze(1)
                centroids = centroids.repeat(1, 3, 50, 1)
                preds = transform_ts_points(
                    preds,
                    world_from_agents.clone()) - centroids[:, :, :, :2].clone()

                # get ground_truth, keyed by track id followed by timestamp
                target_availabilities = []
                target_positions = []
                for track_id, timestamp in zip(data['track_id'],
                                               data['timestamp']):
                    key = str(track_id.item()) + str(timestamp.item())
                    target_positions.append(
                        torch.tensor(gt_dict[key]['coord']))
                    target_availabilities.append(
                        torch.tensor(gt_dict[key]['avail']))

                target_availabilities = torch.stack(
                    target_availabilities).to(device)
                target_positions = torch.stack(target_positions).to(device)

                loss = criterion(target_positions, preds, confidences,
                                 target_availabilities)

                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                lossk.append(loss.item())

                train_progress_bar.set_description(
                    f"loss: {loss.item():.4f} loss[-k:](avg): {np.mean(lossk):.4f}")

                is_last_iter = i == num_iter - 1

                if ((i > 0 and i % cfg['train_params']['checkpoint_steps'] == 0)
                        or is_last_iter):
                    # save model per checkpoint
                    torch.save(model.state_dict(),
                               f'{save_model_dir}/epoch{epoch:02d}_iter{i:05d}.pth')
                    append_train_result(epoch, i, np.mean(lossk), -1,
                                        (time.time() - start) / 60)

                if ((i > 0 and i % cfg['train_params']['validation_steps'] == 0)
                        or is_last_iter):
                    validation_loss = validation(model, device)
                    append_train_result(epoch, i, -1, validation_loss,
                                        (time.time() - start) / 60)
                    # Switch back to training mode after validation.
                    model.train()
                    torch.set_grad_enabled(True)

            except KeyboardInterrupt:
                torch.save(model.state_dict(),
                           f'{save_model_dir}/interrupt_epoch{epoch:02d}_iter{i:05d}.pth')
                # save train result
                results = pd.DataFrame(train_result)
                results.to_csv(
                    f"experiment/{exp_id}/interrupt_train_result.csv", index=False)
                print(f"Total training time is {(time.time()-start)/60} mins")
                print(results)
                # Bare raise keeps the original traceback.
                raise

        del tr_it, train_progress_bar

    # save train result
    results = pd.DataFrame(train_result)
    results.to_csv(f"experiment/{exp_id}/train_result.csv", index=False)
    print(f"Total training time is {(time.time()-start)/60} mins")
    print(results)
Esempio n. 12
0
def main(conf):
    """Evaluate a pretrained separation model on the MUSDB test set.

    For every test track the model is reloaded from ``conf['model_path']``,
    optionally fine-tuned on the track's own mixture (every method except
    ``'baseline'``), applied to the full mixture and scored with
    ``evaluate_mia``. Per-track partial results and the final aggregated
    results are written below ``conf['exp_dir']``.

    Args:
        conf: Configuration dictionary. The final configuration is dumped to
            ``<exp_dir>/conf.yml`` first; afterwards ``n_save_ex`` and
            ``stop_index`` are overwritten in place when given as ``-1``.

    Raises:
        ValueError: If ``conf['method']`` is not ``'baseline'``,
            ``'LmixLact'`` or ``'Lmix'``.
    """
    if conf['method'] not in ['baseline', 'LmixLact', 'Lmix']:
        raise ValueError("method must be baseline, LmixLact or Lmix")

    # Set random seeds for pytorch, numpy and the stdlib `random` module.
    # `random` drives the choice of saved examples below, so it has to be
    # seeded as well for runs to be reproducible.
    th.manual_seed(conf['seed'])
    np.random.seed(conf['seed'])
    random.seed(conf['seed'])

    # Create experiment folder and save conf file with the final configuration
    os.makedirs(conf['exp_dir'], exist_ok=True)
    conf_path = os.path.join(conf['exp_dir'], 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Load test set. Be careful about is_wav!
    test_set = musdb.DB(root=conf['musdb_path'], subsets=["test"], is_wav=True)

    # Randomly choose the indexes of sentences to save.
    ex_save_dir = os.path.join(conf['exp_dir'], 'examples/')
    if conf['n_save_ex'] == -1:
        conf['n_save_ex'] = len(test_set)
    save_idx = random.sample(range(len(test_set)), conf['n_save_ex'])

    # If stop_index==-1, evaluate the whole test set
    if conf['stop_index'] == -1:
        conf['stop_index'] = len(test_set)

    # prepare data frames
    results_applyact = museval.EvalStore()
    results_adapt = museval.EvalStore()
    silence_adapt = pd.DataFrame({
        'target': [],
        'PES': [],
        'EPS': [],
        'track': []
    })

    # Loop over test examples
    for idx in range(len(test_set)):
        # Gradients are only needed during the optional fine-tuning stage.
        torch.set_grad_enabled(False)
        track = test_set.tracks[idx]
        print(idx, str(track.name))

        # Create local directory
        local_save_dir = os.path.join(ex_save_dir, str(track.name))
        os.makedirs(local_save_dir, exist_ok=True)

        # Load mixture and normalise it with the statistics of its mono mix
        mix = th.from_numpy(track.audio).t().float()
        ref = mix.mean(dim=0)  # mono mixture
        mix = (mix - ref.mean()) / ref.std()

        # Load pretrained model. It is reloaded for every track so that the
        # fine-tuning of one track never leaks into the next one.
        # NOTE(review): torch.load unpickles arbitrary objects; only point
        # conf['model_path'] at trusted checkpoints.
        klass, args, kwargs, state = torch.load(conf['model_path'], 'cpu')
        model = klass(*args, **kwargs)
        model.load_state_dict(state)

        # Handle device placement
        if conf['use_gpu']:
            model.cuda()
        device = next(model.parameters()).device

        # Create references matrix
        references = th.stack([
            th.from_numpy(track.targets[name].audio) for name in source_names
        ])
        references = references.numpy()

        # Get activations, computed separately for each of the two channels
        H = np.array([track.targets[name].audio for name in source_names])
        _, bn_ch1, _ = compute_activation_confidence(H[:, :, 0],
                                                     theta=conf['th'],
                                                     hilb=False)
        _, bn_ch2, _ = compute_activation_confidence(H[:, :, 1],
                                                     theta=conf['th'],
                                                     hilb=False)
        activations = th.from_numpy(np.stack((bn_ch1, bn_ch2), axis=2))

        # FINE TUNING
        if conf['method'] != 'baseline':
            print('ADAPTATION')
            torch.set_grad_enabled(True)

            # Freeze layers
            freeze(model.encoder)
            freeze(model.separator, n=conf['frozen_layers'])
            if conf['freeze_decoder']:
                freeze(model.decoder)

            # Only the parameters left trainable by the freezing above are
            # handed to the optimizer.
            optimizer = Ranger(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=conf['lr_fine'])
            loss_func = nn.L1Loss()

            # Initialize writer for Tensorboard
            writer = SummaryWriter(log_dir=local_save_dir)
            for epoch in range(conf['epochs_fine']):
                total_loss = 0
                epoch_loss = 0

                total_rec = 0
                epoch_rec = 0
                total_act = 0
                epoch_act = 0

                if conf['monitor_metrics']:
                    total_sdr = dict.fromkeys(source_names, 0)
                    epoch_sdr = dict.fromkeys(source_names, 0)

                    total_sir = dict.fromkeys(source_names, 0)
                    epoch_sir = dict.fromkeys(source_names, 0)

                    total_sar = dict.fromkeys(source_names, 0)
                    epoch_sar = dict.fromkeys(source_names, 0)

                    total_isr = dict.fromkeys(source_names, 0)
                    epoch_isr = dict.fromkeys(source_names, 0)

                # Data loader with eventually data augmentation
                mix_set = DAdataloader(mix.numpy(),
                                       win=conf['win_fine'],
                                       hop=conf['hop_fine'],
                                       sample_rate=conf['sample_rate'],
                                       n_observations=conf['n_observations'],
                                       pitch_list=conf['pitch_list'],
                                       min_semitones=conf['min_semitones'],
                                       max_semitones=conf['max_semitones'],
                                       same_pitch_list_all_chunks=conf[
                                           'same_pitch_list_all_chunks'])

                # Iterate over chunks
                for t, item in enumerate(mix_set):
                    sample, win, _ = item
                    mix_chunk = th.from_numpy(sample[None, :, :]).to(device)
                    est_chunk = model(cp(mix_chunk))

                    act_chunk = activations[None, :,
                                            win, :].transpose(3, 2).to(device)
                    # Penalise energy estimated where the source is inactive.
                    loss_act = loss_func(est_chunk * (1 - act_chunk),
                                         torch.zeros_like(est_chunk))

                    if conf['method'] == 'LmixLact':
                        loss_rec = loss_func(
                            mix_chunk, torch.sum(est_chunk * act_chunk, dim=1))
                        loss = loss_rec + conf['gamma'] * loss_act
                    elif conf['method'] == 'Lmix':
                        loss_rec = loss_func(mix_chunk,
                                             torch.sum(est_chunk, dim=1))
                        loss = loss_rec

                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()

                    # Running means over the chunks seen so far in this epoch
                    total_loss += loss.item()
                    epoch_loss = total_loss / (1 + t)

                    total_rec += loss_rec.item()
                    total_act += loss_act.item()

                    epoch_rec = total_rec / (1 + t)
                    epoch_act = total_act / (1 + t)

                    # Monitor sdr, sir, and sar over epochs
                    if conf['monitor_metrics']:
                        ref_chunk = references[:, win, :]
                        # Chunks in which any reference source is completely
                        # silent are not evaluated.
                        skip = any(
                            np.sum(ref_chunk[i, :, :]**2) == 0
                            for i in range(len(source_names)))
                        if not skip:
                            sdr, isr, sir, sar = museval.evaluate(
                                ref_chunk,
                                est_chunk.squeeze().transpose(
                                    1, 2).detach().cpu().numpy(),
                                win=np.inf)

                            sdr = np.array(sdr)
                            sir = np.array(sir)
                            sar = np.array(sar)
                            isr = np.array(isr)

                            # NOTE(review): the divisor counts every chunk
                            # seen so far, including skipped ones.
                            for i, target in enumerate(source_names):
                                total_sdr[target] += sdr[i]
                                epoch_sdr[target] = total_sdr[target] / (1 + t)

                                total_sir[target] += sir[i]
                                epoch_sir[target] = total_sir[target] / (1 + t)

                                total_sar[target] += sar[i]
                                epoch_sar[target] = total_sar[target] / (1 + t)

                                total_isr[target] += isr[i]
                                epoch_isr[target] = total_isr[target] / (1 + t)

                if conf['monitor_metrics']:
                    for target in source_names:
                        writer.add_scalar("SDR/" + target, epoch_sdr[target],
                                          epoch)
                        writer.add_scalar("SIR/" + target, epoch_sir[target],
                                          epoch)
                        writer.add_scalar("SAR/" + target, epoch_sar[target],
                                          epoch)
                        writer.add_scalar("ISR/" + target, epoch_isr[target],
                                          epoch)

                writer.add_scalar("Loss/total", epoch_loss, epoch)
                writer.add_scalar("Loss/rec", epoch_rec, epoch)
                writer.add_scalar("Loss/act", epoch_act, epoch)

                # The SDR of 'other' only exists when metrics are monitored;
                # printing it unconditionally raised a NameError otherwise.
                epoch_summary = [
                    'epoch, nr of training examples and loss: ', epoch, t,
                    epoch_loss, epoch_rec, epoch_act
                ]
                if conf['monitor_metrics']:
                    epoch_summary.append(epoch_sdr['other'])
                print(*epoch_summary)

            writer.flush()
            writer.close()

        # apply model
        print('Apply model')
        estimates = apply_model(model,
                                mix.to(device),
                                shifts=conf['shifts'],
                                split=conf['split'])
        # Undo the input normalisation
        estimates = estimates * ref.std() + ref.mean()
        estimates = estimates.transpose(1, 2).cpu().numpy()

        # get results of this track
        print('Evaluate model')
        assert references.shape == estimates.shape
        track_store, silence_frames = evaluate_mia(ref=references,
                                                   est=estimates,
                                                   track_name=track.name,
                                                   source_names=source_names,
                                                   eval_silence=True,
                                                   conf=conf)

        # aggregate results over the track and save the partials
        silence_adapt = silence_adapt.append(silence_frames, ignore_index=True)
        silence_adapt.to_json(os.path.join(conf['exp_dir'], 'silence.json'),
                              orient='records')

        results_adapt.add_track(track_store)
        results_adapt.save(os.path.join(conf['exp_dir'],
                                        'bss_eval_tracks.pkl'))
        print(results_adapt)

        # Save some examples with corresponding metrics in a folder
        if idx in save_idx:
            silence_frames.to_json(os.path.join(local_save_dir,
                                                'silence_frames.json'),
                                   orient='records')
            with open(os.path.join(local_save_dir, 'metrics_museval.json'),
                      'w+') as f:
                f.write(track_store.json)
            sf.write(os.path.join(local_save_dir, "mixture.wav"),
                     mix.transpose(0, 1).cpu().numpy(), conf['sample_rate'])
            for name, estimate, reference, activation in zip(
                    source_names, estimates, references, activations):
                print(name)

                # Percentage of samples per activation value (two channels)
                unique, counts = np.unique(activation, return_counts=True)
                print(dict(zip(unique, counts / (len(activation) * 2) * 100)))

                assert estimate.shape == reference.shape
                sf.write(os.path.join(local_save_dir, name + "_est.wav"),
                         estimate, conf['sample_rate'])
                sf.write(os.path.join(local_save_dir, name + "_ref.wav"),
                         reference, conf['sample_rate'])
                sf.write(os.path.join(local_save_dir, name + "_act.wav"),
                         activation.cpu().numpy(), conf['sample_rate'])

        # Evaluate results when applying the activations to the output
        if conf['apply_act_output']:
            track_store_applyact, _ = evaluate_mia(ref=references,
                                                   est=estimates *
                                                   activations.cpu().numpy(),
                                                   track_name=track.name,
                                                   source_names=source_names,
                                                   eval_silence=False,
                                                   conf=conf)

            # aggregate results over the track and save the partials
            results_applyact.add_track(track_store_applyact)
            print('after applying activations')
            print(results_applyact)

            results_applyact.save(
                os.path.join(conf['exp_dir'], 'bss_eval_tracks_applyact.pkl'))

            # Save some examples with corresponding metrics in a folder
            if idx in save_idx:
                with open(
                        os.path.join(local_save_dir,
                                     'metrics_museval_applyact.json'),
                        'w+') as f:
                    f.write(track_store_applyact.json)

            del track_store_applyact

        # Delete some variables
        del references, mix, estimates, track, track_store, silence_frames, model

        # Stop if reached the limit (checked after the track at that index
        # has been processed)
        if idx == conf['stop_index']:
            break

        print('------------------')

    # Print and save aggregated results
    print('Final results')
    print(results_adapt)
    method = museval.MethodStore()
    method.add_evalstore(results_adapt, conf['exp_dir'])
    method.save(os.path.join(conf['exp_dir'], 'bss_eval.pkl'))

    if conf['eval_silence']:
        print(
            "mean over evaluation frames, mean over channels, mean over tracks"
        )
        for target in source_names:
            print(
                target + ' ==>',
                silence_adapt.loc[silence_adapt['target'] == target].mean(
                    axis=0, skipna=True))
        silence_adapt.to_json(os.path.join(conf['exp_dir'], 'silence.json'),
                              orient='records')

    print('Final results apply act')
    print(results_applyact)
    method = museval.MethodStore()
    method.add_evalstore(results_applyact, conf['exp_dir'])
    method.save(os.path.join(conf['exp_dir'], 'bss_eval_applyact.pkl'))
Esempio n. 13
0
    #         sched = CosineDecay(
    #             opt, total_steps=len(loaders['train']) * (epochs - warmup_steps + 1),
    #             linear_start=eta_min, linear_frac=0.1, min_lr=3e-6)
    #         global_step = 0

    iteration = dict(epoch=epoch, train_loss=list(), valid_loss=list())

    for name, loader in loaders.items():
        is_training = name == 'train'
        count = 0
        metric = 0.0

        with torch.set_grad_enabled(is_training):
            for batch_no, batch in enumerate(loader):
                steps[name] += 1
                opt.zero_grad()

                # y = batch['site1']['targets'].to(device)
                y = batch['site1']['targets_one_hot'].to(device)

                out = model(batch['site1']['features'].to(device),
                            batch['site2']['features'].to(device))

                if is_training:
                    global_step += 1
                    loss = loss_fn(out, y)
                    loss.backward()
                    opt.step()

                    #if (epoch == (warmup_steps - 1)) and batch_no == (len(loader) - 1):
                    #    pass  # skip
Esempio n. 14
0
def train_alphaBert_stage1(TS_model,
                           dloader,
                           testloader,
                           lr=1e-4,
                           epoch=10,
                           log_interval=20,
                           cloze_fix=True,
                           use_amp=False,
                           lkahead=False,
                           parallel=True):
    """Run the stage-1 training loop for ``TS_model`` with a token-level
    cross-entropy loss.

    Progress is printed every ``log_interval`` iterations. Every 400
    iterations a checkpoint is saved and one decoded sample is appended to
    ``./result/out_pred_train.csv``; every 999 iterations the model is
    evaluated on ``testloader``. A checkpoint is also saved after every
    epoch, and the per-epoch mean loss is written to
    ``./result/total_loss_pretrain.csv``.

    Args:
        TS_model: Model called as ``TS_model(input_ids=..., attention_mask=...)``
            and returning a 1-tuple holding the prediction scores.
        dloader: Training loader yielding dicts with the keys ``src_token``,
            ``trg``, ``mask_padding`` and ``origin_seq_length``.
        testloader: Loader handed to ``test_alphaBert_stage1``.
        lr: Learning rate of the Ranger optimizer.
        epoch: Number of epochs to train.
        log_interval: Number of iterations between progress prints.
        cloze_fix: Unused; kept for backward compatibility.
        use_amp: Train with apex ``amp`` (opt level O1) when true.
        lkahead: Unused; kept for backward compatibility.
        parallel: Wrap the model in ``torch.nn.DataParallel`` when true.
    """
    global checkpoint_file
    TS_model.to(device)
    model_optimizer = Ranger(TS_model.parameters(), lr=lr)
    if use_amp:
        TS_model, model_optimizer = amp.initialize(TS_model,
                                                   model_optimizer,
                                                   opt_level="O1")
    if parallel:
        TS_model = torch.nn.DataParallel(TS_model)

    TS_model.train()

    # Targets equal to -1 do not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)
    iteration = 0
    total_loss = []
    out_pred_res = []
    out_pred_test = []
    for ep in range(epoch):
        t0 = time.time()
        epoch_loss = 0
        epoch_cases = 0
        for batch_idx, sample in enumerate(dloader):
            model_optimizer.zero_grad()

            src = sample['src_token']
            trg = sample['trg']
            att_mask = sample['mask_padding']
            origin_len = sample['origin_seq_length']

            bs, max_len = src.shape

            src = src.float().to(device)
            trg = trg.long().to(device)
            att_mask = att_mask.float().to(device)
            origin_len = origin_len.to(device)

            prediction_scores, = TS_model(input_ids=src,
                                          attention_mask=att_mask)

            # Flatten to (tokens, classes); the class count is read from the
            # logits instead of being hard-coded.
            n_classes = prediction_scores.size(-1)
            loss = criterion(
                prediction_scores.view(-1, n_classes).contiguous(),
                trg.view(-1).contiguous())

            if use_amp:
                with amp.scale_loss(loss, model_optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            model_optimizer.step()

            with torch.no_grad():
                epoch_loss += loss.item() * bs
                epoch_cases += bs

                if iteration % log_interval == 0:
                    # NOTE(review): batch_size is not defined in this
                    # function; presumably a module-level global - confirm.
                    print('Ep:{} [{} ({:.0f}%)/ ep_time:{:.0f}min] L:{:.4f}'.
                          format(ep, batch_idx * batch_size,
                                 100. * batch_idx / len(dloader),
                                 (time.time() - t0) * len(dloader) /
                                 (60 * (batch_idx + 1)), loss.item()))

                if iteration % 400 == 0:
                    save_checkpoint(checkpoint_file,
                                    'd2s_total.pth',
                                    TS_model,
                                    model_optimizer,
                                    parallel=parallel)
                    # Show the first sample of the batch three times: as fed
                    # to the model, with predictions filled in at the target
                    # positions, and with the true targets filled in.
                    a_ = tokenize_alphabets.convert_idx2str(
                        src[0][:origin_len[0]])
                    print(a_)
                    print(' ******** ******** ******** ')
                    _, show_pred = torch.max(prediction_scores[0], dim=1)
                    err_cloze_ = trg[0] > -1
                    src[0][err_cloze_] = show_pred[err_cloze_].float()
                    b_ = tokenize_alphabets.convert_idx2str(
                        src[0][:origin_len[0]])
                    print(b_)
                    print(' ******** ******** ******** ')
                    src[0][err_cloze_] = trg[0][err_cloze_].float()
                    c_ = tokenize_alphabets.convert_idx2str(
                        src[0][:origin_len[0]])
                    print(c_)

                    out_pred_res.append((ep, a_, b_, c_, err_cloze_))
                    out_pd_res = pd.DataFrame(out_pred_res)
                    out_pd_res.to_csv('./result/out_pred_train.csv', sep=',')

                if iteration % 999 == 0:
                    print(' ===== Show the Test of Pretrain ===== ')
                    test_res = test_alphaBert_stage1(TS_model, testloader)
                    print(' ===== Show the Test of Pretrain ===== ')

                    out_pred_test.append((ep, *test_res))
                    out_pd_test = pd.DataFrame(out_pred_test)
                    out_pd_test.to_csv('./result/out_pred_test.csv', sep=',')

                    # Make sure training continues in train mode even if the
                    # test routine switched the model to eval mode.
                    TS_model.train()

            iteration += 1

        # Checkpoint after every epoch
        save_checkpoint(checkpoint_file,
                        'd2s_total.pth',
                        TS_model,
                        model_optimizer,
                        parallel=parallel)

        print('======= epoch:%i ========' % ep)

        print('++ Ep Time: {:.1f} Secs ++'.format(time.time() - t0))
        total_loss.append(float(epoch_loss / epoch_cases))
        pd_total_loss = pd.DataFrame(total_loss)
        pd_total_loss.to_csv('./result/total_loss_pretrain.csv', sep=',')
    print(total_loss)
Esempio n. 15
0
def train(model, exp_id):
    """Train ``model`` and store checkpoints and a result table for ``exp_id``.

    Each epoch runs ``cfg['train_params']['max_num_steps']`` optimisation
    steps. Checkpoints are written to ``experiment/<exp_id>/save_models``
    every ``checkpoint_steps`` steps and on the last step of an epoch;
    validation runs every ``validation_steps`` steps. The collected rows are
    saved to ``experiment/<exp_id>/train_result.csv``. On Ctrl-C the current
    weights and the partial result table are saved before the interrupt is
    re-raised.

    Args:
        model: Network to train; it is moved to the device returned by
            ``get_device()``.
        exp_id: Experiment identifier used to build the output paths.

    Raises:
        KeyboardInterrupt: Re-raised after the interrupt artefacts are saved.
    """
    from collections import deque

    save_model_dir = f'experiment/{exp_id}/save_models'
    # Also creates missing parent directories and tolerates an existing one.
    os.makedirs(save_model_dir, exist_ok=True)

    # ==== INIT MODEL=================
    device = get_device()

    model.to(device)
    optimizer = Ranger(model.parameters(), lr=cfg["model_params"]["lr"])

    train_dataloader = load_train_data()

    start = time.time()

    train_result = {
        'epoch': [],
        'iter': [],
        'loss[-k:](avg)': [],
        'validation_loss': [],
        'time(minute)': [],
    }

    def _record(epoch_no, iter_no, avg_loss, val_loss):
        # Append one row to the result table; -1 marks "not applicable".
        train_result['epoch'].append(epoch_no)
        train_result['iter'].append(iter_no)
        train_result['loss[-k:](avg)'].append(avg_loss)
        train_result['validation_loss'].append(val_loss)
        train_result['time(minute)'].append((time.time() - start) / 60)

    k = 1000
    # Rolling window over the last k training losses
    lossk = deque(maxlen=k)
    torch.set_grad_enabled(True)
    num_iter = cfg["train_params"]["max_num_steps"]

    for epoch in range(cfg['train_params']['epoch']):
        model.train()
        torch.set_grad_enabled(True)

        tr_it = iter(train_dataloader)
        train_progress_bar = tqdm(range(num_iter))
        optimizer.zero_grad()
        print('epoch:', epoch)
        for i in train_progress_bar:
            try:
                try:
                    data = next(tr_it)
                except StopIteration:
                    # The loader ran out before num_iter steps: start over.
                    tr_it = iter(train_dataloader)
                    data = next(tr_it)
                preds, confidences = forward(data, model, device)

                target_availabilities = data["target_availabilities"].to(
                    device)
                targets = data["target_positions"].to(device)

                loss = criterion(targets, preds, confidences,
                                 target_availabilities)

                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                lossk.append(loss.item())

                train_progress_bar.set_description(
                    f"loss: {loss.item():.4f} loss[-k:](avg): {np.mean(lossk):.4f}"
                )

                if ((i > 0
                     and i % cfg['train_params']['checkpoint_steps'] == 0)
                        or i == num_iter - 1):
                    # save model per checkpoint
                    torch.save(
                        model.state_dict(),
                        f'{save_model_dir}/epoch{epoch:02d}_iter{i:05d}.pth')
                    _record(epoch, i, np.mean(lossk), -1)

                if i > 0 and i % cfg['train_params']['validation_steps'] == 0:
                    validation_loss = validation(model, device)
                    # Validation rows carry -1 in the 'iter' column.
                    _record(epoch, -1, -1, validation_loss)

                    # Validation may have changed mode / grad state.
                    model.train()
                    torch.set_grad_enabled(True)

            except KeyboardInterrupt:
                torch.save(
                    model.state_dict(),
                    f'{save_model_dir}/interrupt_epoch{epoch:02d}_iter{i:05d}.pth'
                )
                # save train result
                results = pd.DataFrame(train_result)
                results.to_csv(
                    f"experiment/{exp_id}/interrupt_train_result.csv",
                    index=False)
                print(f"Total training time is {(time.time()-start)/60} mins")
                print(results)
                # Bare raise keeps the original traceback.
                raise

        del tr_it, train_progress_bar

    # save train result
    results = pd.DataFrame(train_result)
    results.to_csv(f"experiment/{exp_id}/train_result.csv", index=False)
    print(f"Total training time is {(time.time()-start)/60} mins")
    print(results)
Esempio n. 16
0
class Learner:
    """Training driver for an LSTM auto-encoder.

    Keeps one Ranger optimizer and one ``StepLR`` scheduler each for the
    encoder and the decoder, runs the train/validation loop, and writes
    model checkpoints and a CSV loss report to
    ``<cfg.OUTPUT_DIR>/<cfg.EXPERIMENT_NAME>``.
    """

    def __init__(self, autoencoder: LSTMAutoEncoder, train_loader: DataLoader,
                 val_loader: DataLoader, cfg: Config):
        """Store the collaborators and build optimizers and schedulers.

        Args:
            autoencoder: Network exposing ``encoder`` and ``decoder``
                sub-modules.
            train_loader: Loader yielding ``(bboxs, seq_lens, classes)``.
            val_loader: Loader yielding ``(bboxs, seq_lens, classes)``.
            cfg: Configuration object providing the upper-case attributes
                read below.
        """
        self.net = autoencoder
        self.train_loader = train_loader
        self.val_loader = val_loader

        # get from config object
        output_dir = os.path.join(cfg.OUTPUT_DIR, cfg.EXPERIMENT_NAME)
        os.makedirs(output_dir, exist_ok=True)
        self.output_dir = output_dir
        self.device = cfg.DEVICE
        self.epoch_n = cfg.EPOCH_N
        self.save_cycle = cfg.SAVE_CYCLE
        self.verbose_cycle = cfg.VERBOSE_CYCLE
        self.encoder_lr = cfg.ENCODER_LR
        self.decoder_lr = cfg.DECODER_LR
        self.encoder_gamma = cfg.ENCODER_GAMMA
        self.decoder_gamma = cfg.DECODER_GAMMA
        self.encoder_step_cycle = cfg.ENCODER_STEP_CYCLE
        self.decoder_step_cycle = cfg.DECODER_STEP_CYCLE

        # set optimizer and scheduler
        self.encoder_optim = Ranger(params=filter(lambda p: p.requires_grad,
                                                  self.encoder.parameters()),
                                    lr=self.encoder_lr)
        # The decoder uses its own learning rate (previously the encoder's
        # rate was passed here, so DECODER_LR was silently ignored).
        self.decoder_optim = Ranger(params=filter(lambda p: p.requires_grad,
                                                  self.decoder.parameters()),
                                    lr=self.decoder_lr)
        self.encoder_stepper = StepLR(self.encoder_optim,
                                      step_size=self.encoder_step_cycle,
                                      gamma=self.encoder_gamma)
        self.decoder_stepper = StepLR(self.decoder_optim,
                                      step_size=self.decoder_step_cycle,
                                      gamma=self.decoder_gamma)
        self.loss = nn.MSELoss()

        # for book-keeping
        self.crt_epoch = 0
        self.train_losses = []
        self.val_losses = []

    @property
    def encoder(self):
        """Encoder sub-module of the wrapped network."""
        return self.net.encoder

    @property
    def decoder(self):
        """Decoder sub-module of the wrapped network."""
        return self.net.decoder

    @property
    def signature(self):
        """Log prefix naming the current epoch."""
        return f'[Epoch: {self.crt_epoch}]'

    @property
    def model_path(self):
        """Checkpoint path for the current epoch."""
        model_name = f'lstm_ae_{self.crt_epoch:04}.pth'
        return os.path.join(self.output_dir, model_name)

    @property
    def csv_path(self):
        """Path of the CSV loss report."""
        return os.path.join(self.output_dir, 'report.csv')

    def train(self):
        """Run ``epoch_n`` epochs of training followed by validation.

        Logs every ``verbose_cycle`` epochs, saves model and report every
        ``save_cycle`` epochs and once more after the last epoch.
        """
        logger.info('Start training...')

        self.net.to(self.device)

        for epoch_i in range(self.epoch_n):
            self.crt_epoch = epoch_i + 1
            verbose = self.crt_epoch % self.verbose_cycle == 0
            start_t = time.time()

            epoch_min = self._train_one_epoch()
            if verbose:
                train_avg_rmse = self.train_losses[-1]
                logger.info(
                    f'{self.signature}:: Train complete. Avg RMSE: {train_avg_rmse:04f}. Time: {epoch_min:03f} mins'
                )

            epoch_min = self._val_one_epoch()
            if verbose:
                val_avg_rmse = self.val_losses[-1]
                logger.info(
                    f'{self.signature}:: Val complete. Avg RMSE: {val_avg_rmse:04f}. Time: {epoch_min:03f} mins'
                )

            total_epoch_min = (time.time() - start_t) / 60.
            if verbose:
                logger.info(
                    f'{self.signature}:: Completed Train + Val. Time: {total_epoch_min} mins'
                )

            if self.crt_epoch % self.save_cycle == 0:
                self.save_model()
                self.save_report()
                logger.info(
                    f'{self.signature}:: Model saved: {self.model_path}')

        logger.info(f'{self.signature}:: Training complete!')
        self.save_model()
        self.save_report()
        logger.info(
            f'{self.signature}:: Final Model and Report saved: {self.model_path}, {self.csv_path}'
        )

    def _train_one_epoch(self):
        """Train for one pass over ``train_loader``.

        Appends the mean batch RMSE to ``train_losses``.

        Returns:
            Duration of the epoch in minutes.
        """
        self.net.train()
        start_t = time.time()
        rmse_ls = []

        for bboxs, seq_lens, classes in tqdm(self.train_loader):
            self.encoder_optim.zero_grad()
            self.decoder_optim.zero_grad()

            bboxs = bboxs.to(self.device)
            preds = self.net(bboxs)

            # Packing drops the padded time steps so that they do not
            # contribute to the loss.
            preds = pack_padded_sequence(preds,
                                         seq_lens,
                                         batch_first=True,
                                         enforce_sorted=True)
            targets = pack_padded_sequence(bboxs.clone(),
                                           seq_lens,
                                           batch_first=True,
                                           enforce_sorted=True)
            loss = self.loss(preds.data, targets.data)
            rmse = torch.sqrt(loss)

            rmse.backward()
            self.encoder_optim.step()
            self.decoder_optim.step()
            # NOTE(review): the schedulers advance once per batch, so the
            # *_STEP_CYCLE settings count batches, not epochs - confirm
            # this is intended.
            self.encoder_stepper.step()
            self.decoder_stepper.step()

            rmse_ls.append(rmse.item())

        epoch_min = (time.time() - start_t) / 60.
        avg_rmse = sum(rmse_ls) / len(rmse_ls)
        self.train_losses.append(avg_rmse)
        return epoch_min

    def _val_one_epoch(self):
        """Evaluate on ``val_loader`` without gradient tracking.

        Appends the mean batch RMSE to ``val_losses``.

        Returns:
            Duration of the validation pass in minutes.
        """
        self.net.eval()
        start_t = time.time()
        rmse_ls = []

        with torch.no_grad():
            for bboxs, seq_lens, classes in tqdm(self.val_loader):
                bboxs = bboxs.to(self.device)
                preds = self.net(bboxs)

                preds = pack_padded_sequence(preds,
                                             seq_lens,
                                             batch_first=True,
                                             enforce_sorted=True)
                targets = pack_padded_sequence(bboxs.clone(),
                                               seq_lens,
                                               batch_first=True,
                                               enforce_sorted=True)
                loss = self.loss(preds.data, targets.data)
                rmse = torch.sqrt(loss)

                rmse_ls.append(rmse.item())

        epoch_min = (time.time() - start_t) / 60.
        avg_rmse = sum(rmse_ls) / len(rmse_ls)
        self.val_losses.append(avg_rmse)
        return epoch_min

    def save_model(self):
        """Write the network's state dict to ``model_path``."""
        torch.save(self.net.state_dict(), self.model_path)

    def load_model(self, ckpt_path):
        """Load a state dict from ``ckpt_path`` into the network."""
        assert os.path.isfile(ckpt_path), f'Non-exist ckpt_path: {ckpt_path}'
        self.net.load_state_dict(torch.load(ckpt_path))
        logger.info(f'Model loaded: {ckpt_path}')

    def save_report(self):
        """Write the per-epoch train and validation losses to ``csv_path``."""
        losses_dicts = {
            'train_losses': self.train_losses,
            'val_losses': self.val_losses
        }
        df = pd.DataFrame(losses_dicts)
        df.to_csv(self.csv_path, index=False)
Esempio n. 17
0
def _gan_forward(net, datum):
    """Run the generator on one batch and build the discriminator inputs.

    Args:
        net: the loss-wrapped generator; ``net(datum)`` must return
            ``(losses, seg_list, pred_list)``.
        datum: one batch from the training loader; ``datum[0]`` holds the
            input images.

    Returns:
        ``(losses, seg_clas, mask_clas, image)``: the generator's loss dict,
        the segmentation built from the predictions (still attached to the
        generator's graph), the segmentation built from the ground-truth
        data, and the image batch resized to the segmentation size.
    """
    losses, seg_list, pred_list = net(datum)
    seg_clas, mask_clas, _, seg_size = seg_mask_clas(seg_list, pred_list, datum)

    # The discriminator sees the images at the same resolution as the masks.
    # NOTE(review): `cuda0` is presumably a module-level device - confirm.
    image_list = [img.to(cuda0) for img in datum[0]]
    image = interpolate(torch.stack(image_list), size=seg_size,
                        mode='bilinear', align_corners=False)
    return losses, seg_clas, mask_clas, image


def train():
    """Train YOLACT, optionally adversarially against a WGAN discriminator.

    Everything is driven by the module-level ``args`` and ``cfg`` objects:

    * builds the training dataset (and the validation dataset when
      ``args.validation_epoch > 0``);
    * creates the generator (``Yolact``) and, when ``cfg.pred_seg`` is set,
      the discriminator, together with their optimizers and losses;
    * resumes from ``args.resume`` or initializes from the backbone weights;
    * runs the iteration loop with learning-rate warm-up/steps, periodic
      weight saving and optional logging;
    * computes the validation mAP every ``args.validation_epoch`` epochs and
      once more after training.

    A ``KeyboardInterrupt`` during training saves an ``*_interrupt``
    checkpoint (when ``args.interrupt`` is set) and exits the process.
    """
    if not os.path.exists(args.save_folder):
        os.mkdir(args.save_folder)

    dataset = COCODetection(image_path=cfg.dataset.train_images,
                            info_file=cfg.dataset.train_info,
                            transform=SSDAugmentation(MEANS))

    # Validation objects only exist when validation is enabled; every later
    # use is guarded by the same flag.
    validate = args.validation_epoch > 0
    val_dataset = None
    if validate:
        setup_eval()
        val_dataset = COCODetection(image_path=cfg.dataset.valid_images,
                                    info_file=cfg.dataset.valid_info,
                                    transform=BaseTransform(MEANS))

    # Parallel wraps the underlying module, but when saving and loading we don't want that
    yolact_net = Yolact()
    net = yolact_net
    net.train()
    print('\n--- Generator created! ---')

    # NOTE
    # The image size and segmentation size seen by the discriminator are
    # manually set to 138; this might change in the future (for example 550).
    if cfg.pred_seg:
        dis_size = 138
        dis_net  = Discriminator_Wgan(i_size = dis_size, s_size = dis_size)
        # The discriminator's parameter initialization lives inside its class.
        dis_net.train()
        print('--- Discriminator created! ---\n')

    if args.log:
        log = Log(cfg.name, args.log_folder, dict(args._get_kwargs()),
            overwrite=(args.resume is None), log_gpu_stats=args.log_gpu)

    # I don't use the timer during training (I use a different timing method).
    # Apparently there's a race condition with multiple GPUs, so disable it just to be safe.
    timer.disable_all()

    # Both of these can set args.resume to None, so do them before the check
    if args.resume == 'interrupt':
        args.resume = SavePath.get_interrupt(args.save_folder)
    elif args.resume == 'latest':
        args.resume = SavePath.get_latest(args.save_folder, cfg.name)

    if args.resume is not None:
        print('Resuming training, loading {}...'.format(args.resume))
        yolact_net.load_weights(args.resume)

        if args.start_iter == -1:
            args.start_iter = SavePath.from_str(args.resume).iteration
    else:
        print('Initializing weights...')
        yolact_net.init_weights(backbone_path=args.save_folder + cfg.backbone.path)

    # NOTE: Using the Ranger Optimizer for the generator
    optimizer_gen = Ranger(net.parameters(), lr = args.lr, weight_decay=args.decay)

    if cfg.pred_seg:
        optimizer_dis = optim.SGD(dis_net.parameters(), lr=cfg.dis_lr)
        # NOTE(review): the scheduler is created but never stepped in this
        # function - confirm whether it should follow the discriminator loss.
        schedule_dis  = ReduceLROnPlateau(optimizer_dis, mode = 'min', patience=6, min_lr=1E-6)

    criterion     = MultiBoxLoss(num_classes=cfg.num_classes,
                                pos_threshold=cfg.positive_iou_threshold,
                                neg_threshold=cfg.negative_iou_threshold,
                                negpos_ratio=cfg.ohem_negpos_ratio, pred_seg=cfg.pred_seg)

    # Take the advice from WGAN
    criterion_dis = DiscriminatorLoss_Maskrcnn()
    criterion_gen = GeneratorLoss_Maskrcnn()

    if args.batch_alloc is not None:
        # e.g. args.batch_alloc: 24,24
        args.batch_alloc = [int(x) for x in args.batch_alloc.split(',')]
        if sum(args.batch_alloc) != args.batch_size:
            print('Error: Batch allocation (%s) does not sum to batch size (%s).' % (args.batch_alloc, args.batch_size))
            exit(-1)

    net = CustomDataParallel(NetLoss(net, criterion, pred_seg=cfg.pred_seg))

    if args.cuda:
        net     = net.cuda()
        # NOTE
        if cfg.pred_seg:
            dis_net = nn.DataParallel(dis_net)
            dis_net = dis_net.cuda()

    # Initialize everything
    if not cfg.freeze_bn: yolact_net.freeze_bn() # Freeze bn so we don't kill our means
    yolact_net(torch.zeros(1, 3, cfg.max_size, cfg.max_size).cuda())

    if not cfg.freeze_bn: yolact_net.freeze_bn(True)

    iteration = max(args.start_iter, 0)
    last_time = time.time()

    epoch_size = len(dataset) // args.batch_size
    num_epochs = math.ceil(cfg.max_iter / epoch_size)

    # Which learning rate adjustment step are we on? lr' = lr * gamma ^ step_index
    step_index = 0

    data_loader = data.DataLoader(dataset, args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True, collate_fn=detection_collate,
                                  pin_memory=True)
    # NOTE: not consumed by the loop below yet; only built when a validation
    # dataset actually exists.
    val_loader = None
    if validate:
        val_loader = data.DataLoader(val_dataset, args.batch_size,
                                     num_workers=args.num_workers*2,
                                     shuffle=True, collate_fn=detection_collate,
                                     pin_memory=True)

    save_path = lambda epoch, iteration: SavePath(cfg.name, epoch, iteration).get_path(root=args.save_folder)
    time_avg = MovingAverage()

    global loss_types # Forms the print order
    loss_avgs  = { k: MovingAverage(100) for k in loss_types }

    # NOTE
    # Enable AMP
    amp_enable = cfg.amp
    scaler = torch.cuda.amp.GradScaler(enabled=amp_enable)

    print('Begin training!')
    print()
    # try-except so you can use ctrl+c to save early and stop training
    try:
        for epoch in range(num_epochs):
            # Resume from start_iter
            if (epoch+1)*epoch_size < iteration:
                continue

            for datum in data_loader:
                # Stop if we've reached an epoch if we're resuming from start_iter
                if iteration == (epoch+1)*epoch_size:
                    break

                # Stop at the configured number of iterations even if mid-epoch
                if iteration == cfg.max_iter:
                    break

                # Change a config setting if we've reached the specified iteration
                changed = False
                for change in cfg.delayed_settings:
                    if iteration >= change[0]:
                        changed = True
                        cfg.replace(change[1])

                        # Reset the loss averages because things might have changed.
                        # Iterate the MovingAverage objects, not the dict keys.
                        for avg in loss_avgs.values():
                            avg.reset()

                # If a config setting was changed, remove it from the list so we don't keep checking
                if changed:
                    cfg.delayed_settings = [x for x in cfg.delayed_settings if x[0] > iteration]

                # Warm up by linearly interpolating the learning rate from some smaller value
                if cfg.lr_warmup_until > 0 and iteration <= cfg.lr_warmup_until:
                    set_lr(optimizer_gen, (args.lr - cfg.lr_warmup_init) * (iteration / cfg.lr_warmup_until) + cfg.lr_warmup_init)

                # Adjust the learning rate at the given iterations, but also if we resume from past that iteration
                while step_index < len(cfg.lr_steps) and iteration >= cfg.lr_steps[step_index]:
                    step_index += 1
                    set_lr(optimizer_gen, args.lr * (args.gamma ** step_index))

                if cfg.pred_seg:
                    # ====== GAN Train ======
                    for _ in range(cfg.dis_iter):
                        if cfg.amp:
                            # ----- Discriminator part -----
                            with torch.cuda.amp.autocast():
                                _, seg_clas, mask_clas, image = _gan_forward(net, datum)

                                # While training the discriminator no gradient
                                # may flow back into the generator, so the
                                # generated segmentation is detached.
                                output_pred = dis_net(img = image.detach(), seg = seg_clas.detach())
                                output_grou = dis_net(img = image.detach(), seg = mask_clas.detach())

                                # Wasserstein Distance (Earth-Mover)
                                loss_dis = criterion_dis(input=output_grou,target=output_pred)

                            # The previous generator backward also deposited
                            # gradients on the discriminator's parameters;
                            # drop them so this step only uses loss_dis.
                            optimizer_dis.zero_grad()
                            # Scales loss. Calls backward() on scaled loss to create scaled gradients.
                            scaler.scale(loss_dis).backward()
                            scaler.step(optimizer_dis)
                            scaler.update()
                            optimizer_dis.zero_grad()

                            # clip the updated parameters (WGAN weight clipping)
                            for par in dis_net.parameters():
                                par.data.clamp_(-cfg.clip_value, cfg.clip_value)

                            # ----- Generator part -----
                            with torch.cuda.amp.autocast():
                                # NOTE: seg_clas must NOT be detached here, so
                                # the discriminator's signal reaches the generator.
                                losses, seg_clas, mask_clas, image = _gan_forward(net, datum)

                                # GAN MaskRCNN
                                output_pred = dis_net(img = image, seg = seg_clas)
                                output_grou = dis_net(img = image, seg = mask_clas)

                                # Advice from WGAN
                                loss_gen = criterion_gen(input=output_grou,target=output_pred)

                                # Only optimizer_gen is stepped below, so only
                                # YOLACT's weights are updated by this loss.
                                losses = { k: (v).mean() for k,v in losses.items() } # Mean here because Dataparallel
                                loss = sum([losses[k] for k in losses])
                                loss += loss_gen

                            # Generator backprop
                            scaler.scale(loss).backward()
                            scaler.step(optimizer_gen)
                            scaler.update()
                            optimizer_gen.zero_grad()

                        else:
                            # ----- Discriminator part -----
                            _, seg_clas, mask_clas, image = _gan_forward(net, datum)

                            output_pred = dis_net(img = image.detach(), seg = seg_clas.detach())
                            output_grou = dis_net(img = image.detach(), seg = mask_clas.detach())
                            loss_dis = criterion_dis(input=output_grou,target=output_pred)

                            # Drop the gradients left on the discriminator by
                            # the previous generator backward pass.
                            optimizer_dis.zero_grad()
                            loss_dis.backward()
                            optimizer_dis.step()
                            optimizer_dis.zero_grad()
                            for par in dis_net.parameters():
                                par.data.clamp_(-cfg.clip_value, cfg.clip_value)

                            # ----- Generator part -----
                            losses, seg_clas, mask_clas, image = _gan_forward(net, datum)

                            # GAN MaskRCNN
                            output_pred = dis_net(img = image, seg = seg_clas)
                            output_grou = dis_net(img = image, seg = mask_clas)

                            loss_gen = criterion_gen(input=output_grou,target=output_pred)

                            # Only optimizer_gen is stepped below, so only
                            # YOLACT's weights are updated by this loss.
                            losses = { k: (v).mean() for k,v in losses.items() } # Mean here because Dataparallel
                            loss = sum([losses[k] for k in losses])
                            loss += loss_gen
                            # Do this to free up vram even if loss is not finite
                            loss.backward()
                            if torch.isfinite(loss).item():
                                optimizer_gen.step()
                            # Clear the gradients only AFTER the step; zeroing
                            # them first would discard the update.
                            optimizer_gen.zero_grad()
                else:
                    # ====== Normal YOLACT Train ======
                    # Zero the grad to get ready to compute gradients
                    optimizer_gen.zero_grad()
                    # Forward Pass + Compute loss at the same time (see CustomDataParallel and NetLoss)
                    losses = net(datum)
                    losses = { k: (v).mean() for k,v in losses.items() } # Mean here because Dataparallel
                    loss = sum([losses[k] for k in losses])

                    # Backprop
                    loss.backward() # Do this to free up vram even if loss is not finite
                    if torch.isfinite(loss).item():
                        optimizer_gen.step()

                # Add the loss to the moving average for bookkeeping
                for k in losses:
                    loss_avgs[k].add(losses[k].item())

                cur_time  = time.time()
                elapsed   = cur_time - last_time
                last_time = cur_time

                # Exclude graph setup from the timing information
                if iteration != args.start_iter:
                    time_avg.add(elapsed)

                if iteration % 10 == 0:
                    eta_str = str(datetime.timedelta(seconds=(cfg.max_iter-iteration) * time_avg.get_avg())).split('.')[0]

                    total = sum([loss_avgs[k].get_avg() for k in losses])
                    loss_labels = sum([[k, loss_avgs[k].get_avg()] for k in loss_types if k in losses], [])
                    if cfg.pred_seg:
                        print(('[%3d] %7d ||' + (' %s: %.3f |' * len(losses)) + ' T: %.3f || ETA: %s || timer: %.3f')
                                % tuple([epoch, iteration] + loss_labels + [total, eta_str, elapsed]), flush=True)
                    # Loss Key:
                    #  - B: Box Localization Loss
                    #  - C: Class Confidence Loss
                    #  - M: Mask Loss
                    #  - P: Prototype Loss
                    #  - D: Coefficient Diversity Loss
                    #  - E: Class Existence Loss
                    #  - S: Semantic Segmentation Loss
                    #  - T: Total loss

                if args.log:
                    precision = 5
                    loss_info = {k: round(losses[k].item(), precision) for k in losses}
                    loss_info['T'] = round(loss.item(), precision)

                    if args.log_gpu:
                        log.log_gpu_stats = (iteration % 10 == 0) # nvidia-smi is sloooow

                    # NOTE(review): `cur_lr` is not assigned in this function;
                    # presumably a module-level value maintained by set_lr - confirm.
                    log.log('train', loss=loss_info, epoch=epoch, iter=iteration,
                        lr=round(cur_lr, 10), elapsed=elapsed)

                    log.log_gpu_stats = args.log_gpu

                iteration += 1

                if iteration % args.save_interval == 0 and iteration != args.start_iter:
                    if args.keep_latest:
                        latest = SavePath.get_latest(args.save_folder, cfg.name)

                    print('Saving state, iter:', iteration)
                    yolact_net.save_weights(save_path(epoch, iteration))

                    if args.keep_latest and latest is not None:
                        if args.keep_latest_interval <= 0 or iteration % args.keep_latest_interval != args.save_interval:
                            print('Deleting old save...')
                            os.remove(latest)

            # This is done per epoch
            if validate:
                if epoch % args.validation_epoch == 0 and epoch > 0:
                    cfg.gan_eval = False
                    # The discriminator only exists in GAN mode.
                    if cfg.pred_seg:
                        dis_net.eval()
                    compute_validation_map(epoch, iteration, yolact_net, val_dataset, log if args.log else None)
                    # Put the discriminator back in training mode, otherwise
                    # it would stay in eval mode for the rest of training.
                    if cfg.pred_seg:
                        dis_net.train()

        # Compute validation mAP after training is finished (needs the
        # validation dataset, which only exists when validation is enabled)
        if validate:
            compute_validation_map(epoch, iteration, yolact_net, val_dataset, log if args.log else None)
    except KeyboardInterrupt:
        if args.interrupt:
            print('Stopping early. Saving network...')

            # Delete previous copy of the interrupted network so we don't spam the weights folder
            SavePath.remove_interrupt(args.save_folder)

            yolact_net.save_weights(save_path(epoch, repr(iteration) + '_interrupt'))
        exit()

    yolact_net.save_weights(save_path(epoch, iteration))
Esempio n. 18
0
def main(args, logger):
    """Train the ``Vgg`` classifier on the ``Lung`` dataset with apex AMP.

    Args:
        args: parsed options. Reads ``subTensorboardDir``, ``dataDir``,
            ``inputSize``, ``batchSize``, ``valBatchSize``, ``numWorkers``,
            ``lr``, ``apexType``, ``epoch``, ``evalFrequency``,
            ``saveFrequency``, ``msgFrequency`` and ``subModelDir``.
        logger: logger used for progress messages and passed to the
            project-level ``eval`` helper.

    Side effects:
        Writes TensorBoard scalars, a checkpoint ``out_<epoch>.pt`` every
        ``args.saveFrequency`` epochs (model, optimizer and amp state) and
        the final weights to ``final.pth`` under ``args.subModelDir``.
    """
    writer = SummaryWriter(args.subTensorboardDir)
    try:
        model = Vgg().to(device)
        trainSet = Lung(rootDir=args.dataDir, mode='train', size=args.inputSize)
        valSet = Lung(rootDir=args.dataDir, mode='test', size=args.inputSize)
        trainDataloader = DataLoader(trainSet,
                                     batch_size=args.batchSize,
                                     drop_last=True,
                                     shuffle=True,
                                     pin_memory=False,
                                     num_workers=args.numWorkers)
        valDataloader = DataLoader(valSet,
                                   batch_size=args.valBatchSize,
                                   drop_last=False,
                                   shuffle=False,
                                   pin_memory=False,
                                   num_workers=args.numWorkers)
        criterion = nn.CrossEntropyLoss()
        optimizer = Ranger(model.parameters(), lr=args.lr)
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.apexType)

        # Global step counter (named to avoid shadowing the builtin `iter`).
        globalStep = 0
        runningLoss = []
        for epoch in range(args.epoch):
            if epoch != 0 and epoch % args.evalFrequency == 0:
                # `eval` here is the project's evaluation helper, not the builtin.
                f1, acc = eval(model, valDataloader, logger)
                writer.add_scalars('f1_acc', {'f1': f1,
                                              'acc': acc}, globalStep)

            if epoch != 0 and epoch % args.saveFrequency == 0:
                modelName = osp.join(args.subModelDir, 'out_{}.pt'.format(epoch))
                # Unwrap a (Distributed)DataParallel container so the saved
                # keys carry no 'module.' prefix.
                rawModel = model.module if hasattr(model, 'module') else model
                checkpoint = {
                    'model': rawModel.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'amp': amp.state_dict()
                }
                # Written once; a bare state dict saved to the same path
                # beforehand would just be overwritten.
                torch.save(checkpoint, modelName)

            # NOTE(review): the evaluation helper may leave the model in eval
            # mode; re-asserting train mode is a no-op otherwise - confirm.
            model.train()
            for img, lb, _ in trainDataloader:
                globalStep += 1
                img, lb = img.to(device), lb.to(device)
                optimizer.zero_grad()
                outputs = model(img)
                loss = criterion(outputs.squeeze(), lb.long())
                # apex scales the loss so fp16 gradients do not underflow.
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
                runningLoss.append(loss.item())

                if globalStep % args.msgFrequency == 0:
                    avgLoss = np.mean(runningLoss)
                    runningLoss = []
                    lr = optimizer.param_groups[0]['lr']
                    logger.info(f'epoch: {epoch} / {args.epoch}, '
                                f'iter: {globalStep} / {len(trainDataloader) * args.epoch}, '
                                f'lr: {lr}, '
                                f'loss: {avgLoss:.4f}')
                    writer.add_scalar('loss', avgLoss, globalStep)

        eval(model, valDataloader, logger)
        modelName = osp.join(args.subModelDir, 'final.pth')
        rawModel = model.module if hasattr(model, 'module') else model
        torch.save(rawModel.state_dict(), modelName)
    finally:
        # Flush pending TensorBoard events even if training fails.
        writer.close()
Esempio n. 19
0
            #print(target.shape)
            #print(inputx.shape)
            output = model(inputs)

            ############################ 数据出口#########################################
            ############  请将你的output数据重新转换成 bscwh格式进行下一步#################
            #output = output.unsqueeze(2)

            ###############################################################################
            #loss1 = criterionmse(output, targets)
            #loss2 = criterionmae(output, targets)
            print(targets.shape)
            loss = criterionbmse(output, targets)
            #loss3 = criteriongdl(output,target)
            #loss = loss1+loss2
            optimizer.zero_grad()
            loss.backward()
            losses.update(loss.item())
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()
            #t_rmse = rmse_loss(output, targets)
            #rmse.update(t_rmse.item())

            #output_np = np.clip(output.detach().cpu().numpy(), 0, 1)
            #target_np = np.clip(targets.detach().cpu().numpy(), 0, 1)
            logging.info('[{0}][{1}][{2}]\t'
                         'lr: {lr:.5f}\t'
                         'loss: {loss.val:.6f} ({loss.avg:.6f})'.format(
                             epoch,
                             headid,
                             ind,
Esempio n. 20
0
                vutils.save_image(lab, "%s/debugUnetLabel.png" % (savePath),
                                  normalize=False)
                print ("initial label",lab.sum(),lab.max(),lab.mean())

                plt.figure(figsize=(12, 12))
                sr=lab.view(-1)
                sr=torch.sort(sr)[0]
                print ("sr",sr.shape)
                plt.plot(sr[-1000:].numpy())
                plt.savefig("%s/labvalues_px%s_cr%s_BN%s.png" % (savePath, opt.imageSize, opt.imageCrop,str(opt.BN)))
                plt.close()
                #raise Exception

            lab = lab.to(device).float()  # not long but float...

            opti.zero_grad()
            pred = netD(im)
            err = criterion(pred, lab)

            err.backward()
            opti.step()

            buf.append(err.item())

            for j in range(lab.shape[0]):##single element, 1 number per patch
                if lab[j].sum()==0:
                    bufN.append(pred[j].max().item())
                    bufME.append(pred[j].mean().item())
                else:
                    bufP.append(pred[j].max().item())
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file.
    """
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator. When `config.is_train` is set this
          is indexed as `(train_loader, valid_loader)`; otherwise it is
          used directly as the test loader.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]

            # `next(iter(...))` instead of `.next()`: the latter is not
            # available on Python 3 iterators / current DataLoader iterators.
            image_tmp, _ = next(iter(self.train_loader))
            self.image_size = (image_tmp.shape[2], image_tmp.shape[3])

            if 'MNIST' in config.dataset_name or config.dataset_name == 'CIFAR':
                self.num_train = len(self.train_loader.sampler.indices)
                self.num_valid = len(self.valid_loader.sampler.indices)
            elif config.dataset_name == 'ImageNet':
                # the ImageNet cannot be sampled, otherwise this part will be wrong.
                self.num_train = 100000  #len(train_dataset) in data_loader.py, wrong: len(self.train_loader)
                self.num_valid = 10000  #len(self.valid_loader)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)

            image_tmp, _ = next(iter(self.test_loader))
            self.image_size = (image_tmp.shape[2], image_tmp.shape[3])

        # assign number of channels and classes of images in this dataset,
        # maybe there is a more robust way
        if 'MNIST' in config.dataset_name:
            self.num_channels = 1
            self.num_classes = 10
        elif config.dataset_name == 'ImageNet':
            self.num_channels = 3
            self.num_classes = 1000
        elif config.dataset_name == 'CIFAR':
            self.num_channels = 3
            self.num_classes = 10

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr
        self.loss_fun_baseline = config.loss_fun_baseline
        self.loss_fun_action = config.loss_fun_action
        self.weight_decay = config.weight_decay

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.best_train_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq

        # Only the prefix differs between the GPU and CPU model names.
        prefix = 'ram_gpu' if config.use_gpu else 'ram'
        self.model_name = (prefix + '_{0}_{1}_{2}x{3}_{4}_{5:1.2f}_{6}').format(
            config.PBSarray_ID, config.num_glimpses, config.patch_size,
            config.patch_size, config.hidden_size, config.std,
            config.dropout)

        self.plot_dir = './plots/' + self.model_name + '/'
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging
        # The writer is kept on the instance so that train_one_epoch() and
        # validate() can reach it.
        self.writer = None
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            # NOTE(review): `configure` presumably comes from
            # tensorboard_logger; confirm it is still needed next to
            # SummaryWriter.
            configure(tensorboard_dir)
            # SummaryWriter takes `log_dir`, not `logs_dir`.
            self.writer = SummaryWriter(log_dir=tensorboard_dir)

        # build DRAMBUTD model
        self.model = RecurrentAttention(self.patch_size, self.num_channels,
                                        self.image_size, self.std,
                                        self.hidden_size, self.num_classes,
                                        config)
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum(p.numel() for p in self.model.parameters())))

        # initialize optimizer and scheduler
        # `scheduler` stays None unless the plateau scheduler is requested.
        self.scheduler = None
        if config.optimizer in ('SGD', 'ReduceLROnPlateau'):
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.lr,
                                       momentum=self.momentum,
                                       weight_decay=self.weight_decay)
            if config.optimizer == 'ReduceLROnPlateau':
                # The scheduler wraps an optimizer, so one has to exist
                # first; weight decay belongs to the optimizer, not to the
                # scheduler.
                self.scheduler = ReduceLROnPlateau(self.optimizer,
                                                   'min',
                                                   patience=self.lr_patience)
        elif config.optimizer == 'Adadelta':
            self.optimizer = optim.Adadelta(self.model.parameters(),
                                            weight_decay=self.weight_decay)
        elif config.optimizer == 'Adam':
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=3e-4,
                                        weight_decay=self.weight_decay)
        elif config.optimizer == 'AdaBound':
            self.optimizer = adabound.AdaBound(self.model.parameters(),
                                               lr=3e-4,
                                               final_lr=0.1,
                                               weight_decay=self.weight_decay)
        elif config.optimizer == 'Ranger':
            self.optimizer = Ranger(self.model.parameters(),
                                    weight_decay=self.weight_decay)

    def reset(self, x, SM):
        """
        Initialize the hidden state of the core network
        and the location vector.

        This is called once every time a new minibatch
        `x` is introduced. `self.batch_size` must already be set.

        Returns
        -------
        - (h_t1, h_t2, l_t, cell_state1, cell_state2, SM_local_smooth):
          `h_t2`, `l_t` and `SM_local_smooth` come from
          `self.model.initialize(x, SM)`; the other three are zero tensors
          of shape (batch_size, hidden_size).
        """
        h_t2, l_t, SM_local_smooth = self.model.initialize(x, SM)

        # Allocate on the same device as the minibatch (callers move `x` to
        # the GPU when `use_gpu` is set) instead of going through the legacy
        # torch.cuda.FloatTensor type.
        def zeros():
            return torch.zeros(self.batch_size,
                               self.hidden_size,
                               dtype=torch.float32,
                               device=x.device)

        # initialize hidden state 1 as 0 vector to avoid the directly
        # classification from context
        h_t1 = zeros()
        cell_state1 = zeros()
        cell_state2 = zeros()

        return h_t1, h_t2, l_t, cell_state1, cell_state2, SM_local_smooth

    @staticmethod
    def _split_channels(x_raw):
        """Split a raw batch into the image (channel 0) and its saliency map (channel 1)."""
        x = x_raw[:, 0, ...].unsqueeze(1)
        SM = x_raw[:, 1, ...].unsqueeze(1)
        return x, SM

    def _run_glimpses(self, x, SM):
        """
        Unroll the model for `num_glimpses` steps on one minibatch.

        Sets `self.batch_size` from `x` and resets the recurrent state.

        Returns
        -------
        - log_probas: output of the final (`last=True`) model call.
        - baselines: per-step baselines stacked and transposed to
          (batch, num_glimpses).
        - log_pi: per-step `p` outputs, stacked the same way.
        - locs: list with the first 9 locations of every step.
        """
        self.batch_size = x.shape[0]
        (h_t1, h_t2, l_t, cell_state1, cell_state2,
         SM_local_smooth) = self.reset(x, SM)

        locs = []
        log_pi = []
        baselines = []

        for _ in range(self.num_glimpses - 1):
            # forward pass through model
            (h_t1, h_t2, l_t, b_t, p, cell_state1, cell_state2,
             SM_local_smooth) = self.model(x, l_t, h_t1, h_t2, cell_state1,
                                           cell_state2, SM, SM_local_smooth)
            locs.append(l_t[0:9])
            baselines.append(b_t)
            log_pi.append(p)

        # last iteration additionally yields the class log-probabilities
        (h_t1, h_t2, l_t, b_t, log_probas, p, cell_state1, cell_state2,
         SM_local_smooth) = self.model(x,
                                       l_t,
                                       h_t1,
                                       h_t2,
                                       cell_state1,
                                       cell_state2,
                                       SM,
                                       SM_local_smooth,
                                       last=True)
        log_pi.append(p)
        baselines.append(b_t)
        locs.append(l_t[0:9])

        # convert list to tensors and reshape
        baselines = torch.stack(baselines).transpose(1, 0)
        log_pi = torch.stack(log_pi).transpose(1, 0)
        return log_probas, baselines, log_pi, locs

    def _hybrid_loss(self, log_probas, y, baselines, log_pi):
        """
        Combine action, baseline and REINFORCE losses for one minibatch.

        Returns
        -------
        - (loss, acc): the summed hybrid loss and the accuracy in percent,
          both as 0-dim tensors.
        """
        # calculate reward
        predicted = torch.max(log_probas, 1)[1]
        if self.loss_fun_baseline == 'cross_entropy':
            # cross_entropy needs a long, batch-sized target, but R also has
            # to be subtracted by the baseline whose size is N x num_glimpse,
            # so it is expanded only after the losses are computed.
            R = (predicted.detach() == y).long()
            loss_action, loss_baseline = self.choose_loss_fun(
                log_probas, y, baselines, R)
            R = R.float().unsqueeze(1).repeat(1, self.num_glimpses)
        else:
            R = (predicted.detach() == y).float()
            R = R.unsqueeze(1).repeat(1, self.num_glimpses)
            loss_action, loss_baseline = self.choose_loss_fun(
                log_probas, y, baselines, R)

        # compute reinforce loss
        # summed over timesteps and averaged across batch
        adjusted_reward = R - baselines.detach()
        loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
        loss_reinforce = torch.mean(loss_reinforce, dim=0)

        # sum up into a hybrid loss
        loss = loss_action + loss_baseline + loss_reinforce

        # compute accuracy
        correct = (predicted == y).float()
        acc = 100 * (correct.sum() / len(y))
        return loss, acc

    @staticmethod
    def _to_numpy(imgs, locs):
        """Convert lists of image / location tensors to numpy arrays on the CPU."""
        imgs = [g.detach().cpu().numpy().squeeze() for g in imgs]
        locs = [l.detach().cpu().numpy() for l in locs]
        return imgs, locs

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        Training stops early once the validation accuracy has not improved
        for more than `train_patience` consecutive epochs.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        for epoch in range(self.start_epoch, self.epochs):

            # Report the rate the optimizer is actually using; `self.lr` is
            # only the configured initial value.
            current_lr = self.optimizer.param_groups[0]['lr']
            print('\nEpoch: {}/{} - LR: {:.6f}'.format(epoch + 1, self.epochs,
                                                       current_lr))

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            # reduce lr if validation loss plateaus
            if self.scheduler is not None:
                self.scheduler.step(valid_loss)

            is_best_valid = valid_acc > self.best_valid_acc
            is_best_train = train_acc > self.best_train_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f}"

            if is_best_train:
                msg1 += " [*]"

            if is_best_valid:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(msg.format(train_loss, train_acc, valid_loss, valid_acc))

            # check for improvement
            if not is_best_valid:
                self.counter += 1
            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.best_train_acc = max(train_acc, self.best_train_acc)
            self.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optim_state': self.optimizer.state_dict(),
                    'best_valid_acc': self.best_valid_acc,
                    'best_train_acc': self.best_train_acc,
                }, is_best_valid)

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.

        Returns
        -------
        - (avg_loss, avg_acc) as tracked by the AverageMeters.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()
        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x_raw, y) in enumerate(self.train_loader):
                if self.use_gpu:
                    x_raw, y = x_raw.cuda(), y.cuda()

                # detach images and their saliency maps
                x, SM = self._split_channels(x_raw)

                # only the first batch of every `plot_freq`-th epoch is dumped
                plot = (epoch % self.plot_freq == 0) and (i == 0)

                # save images
                imgs = [x[0:9]]

                # extract the glimpses
                log_probas, baselines, log_pi, locs = self._run_glimpses(
                    x, SM)
                loss, acc = self._hybrid_loss(log_probas, y, baselines,
                                              log_pi)

                # store
                loss_value = loss.item()
                acc_value = acc.item()
                losses.update(loss_value, x.size()[0])
                accs.update(acc_value, x.size()[0])

                # compute gradients and update SGD
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure elapsed time
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss_value, acc_value)))
                pbar.update(self.batch_size)

                # dump the glimpses and locs
                if plot:
                    imgs, locs = self._to_numpy(imgs, locs)
                    # context managers so the pickle files are closed
                    with open(self.plot_dir + "g_{}.p".format(epoch + 1),
                              "wb") as f:
                        pickle.dump(imgs, f)
                    with open(self.plot_dir + "l_{}.p".format(epoch + 1),
                              "wb") as f:
                        pickle.dump(locs, f)
                    sio.savemat(self.plot_dir +
                                "data_train_{}.mat".format(epoch + 1),
                                mdict={
                                    'location': locs,
                                    'patch': imgs
                                })

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    # add_scalar needs a number, not the meter object
                    self.writer.add_scalar('Loss/train', losses.avg,
                                           iteration)
                    self.writer.add_scalar('Accuracy/train', accs.avg,
                                           iteration)

            return losses.avg, accs.avg

    def validate(self, epoch):
        """
        Evaluate the model on the validation set.

        Every batch is duplicated `M` times and the model outputs are
        averaged over the copies before the loss is computed.

        Returns
        -------
        - (avg_loss, avg_acc) as tracked by the AverageMeters.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        # No parameter update happens here, so gradients are not tracked.
        # NOTE(review): the model is never switched to eval() mode; confirm
        # whether that is intended if the model uses dropout.
        with torch.no_grad():
            for i, (x_raw, y) in enumerate(self.valid_loader):
                if self.use_gpu:
                    x_raw, y = x_raw.cuda(), y.cuda()

                # detach images and their saliency maps
                x, SM = self._split_channels(x_raw)

                # duplicate M times
                x = x.repeat(self.M, 1, 1, 1)
                SM = SM.repeat(self.M, 1, 1, 1)

                # extract the glimpses
                log_probas, baselines, log_pi, _ = self._run_glimpses(x, SM)

                # average over the M copies
                log_probas = log_probas.view(self.M, -1,
                                             log_probas.shape[-1])
                log_probas = torch.mean(log_probas, dim=0)

                baselines = baselines.contiguous().view(
                    self.M, -1, baselines.shape[-1])
                baselines = torch.mean(baselines, dim=0)

                log_pi = log_pi.contiguous().view(self.M, -1,
                                                  log_pi.shape[-1])
                log_pi = torch.mean(log_pi, dim=0)

                loss, acc = self._hybrid_loss(log_probas, y, baselines,
                                              log_pi)

                # store
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.valid_loader) + i
                    # add_scalar needs a number, not the meter object
                    self.writer.add_scalar('Accuracy/valid', accs.avg,
                                           iteration)
                    self.writer.add_scalar('Loss/valid', losses.avg,
                                           iteration)

        return losses.avg, accs.avg

    def choose_loss_fun(self, log_probas, y, baselines, R):
        """
        Use a dictionary of function handles as a replacement
        for a switch-case.

        Be careful of the argument data type and shape!!!

        Returns
        -------
        - (loss_action, loss_baseline): `loss_fun_action` applied to
          (log_probas, y) and `loss_fun_baseline` applied to (baselines, R).

        Raises
        ------
        - KeyError: if either configured loss name is not in the pool.
        """
        loss_fun_pool = {
            'mse': F.mse_loss,
            'l1': F.l1_loss,
            'nll': F.nll_loss,
            'smooth_l1': F.smooth_l1_loss,
            'kl_div': F.kl_div,
            'cross_entropy': F.cross_entropy
        }

        action_fn = loss_fun_pool[self.loss_fun_action]
        baseline_fn = loss_fun_pool[self.loss_fun_baseline]
        return action_fn(log_probas, y), baseline_fn(baselines, R)

    def test(self):
        """
        Test the model on the held-out test data.
        This function should only be called at the very
        end once the model has finished training.

        Prints the accuracy and dumps the glimpse images / locations of the
        most recent batch into `plot_dir` (the files are overwritten on
        every batch).
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        # torch.no_grad() replaces the removed `Variable(volatile=True)`.
        with torch.no_grad():
            for i, (x_raw, y) in enumerate(self.test_loader):
                if self.use_gpu:
                    x_raw, y = x_raw.cuda(), y.cuda()

                # Same image / saliency-map split as in training.
                # NOTE(review): assumes test batches carry the saliency map
                # in channel 1 like the train/valid batches do; confirm.
                x, SM = self._split_channels(x_raw)

                # duplicate M times
                x = x.repeat(self.M, 1, 1, 1)
                SM = SM.repeat(self.M, 1, 1, 1)

                # save images and glimpse location
                imgs = [x[0:9]]

                log_probas, _, _, locs = self._run_glimpses(x, SM)

                log_probas = log_probas.view(self.M, -1,
                                             log_probas.shape[-1])
                log_probas = torch.mean(log_probas, dim=0)

                pred = log_probas.max(1, keepdim=True)[1]
                # .item() keeps `correct` a plain int for the final report
                correct += pred.eq(y.view_as(pred)).sum().item()

                # dump test data
                imgs, locs = self._to_numpy(imgs, locs)
                with open(self.plot_dir + "g_test.p", "wb") as f:
                    pickle.dump(imgs, f)
                with open(self.plot_dir + "l_test.p", "wb") as f:
                    pickle.dump(locs, f)
                sio.savemat(self.plot_dir + "test_transient.mat",
                            mdict={'location': locs})

        perc = (100. * correct) / (self.num_test)
        error = 100 - perc
        print('[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)'.format(
            correct, self.num_test, perc, error))

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date. This function is used when the model is being evaluated
        on the test data.

        If this model has reached the best validation accuracy thus
        far, a separate file with the suffix `best` is created.
        """
        # print("[*] Saving model to {}".format(self.ckpt_dir))

        # make sure the target directory exists before writing into it
        os.makedirs(self.ckpt_dir, exist_ok=True)

        filename = self.model_name + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        if is_best:
            filename = self.model_name + '_model_best.pth.tar'
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """
        Load the best copy of a model. This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        if best:
            filename = self.model_name + '_model_best.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        ckpt = torch.load(ckpt_path)

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        # save_checkpoint() callers also store the best train accuracy;
        # `.get` keeps checkpoints without that key loadable.
        self.best_train_acc = ckpt.get('best_train_acc', 0.)
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'], ckpt['best_valid_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch']))
Esempio n. 22
0
def train(model,
          dataset,
          val_X,
          val_y,
          batch_size=512,
          epochs=500,
          epoch_show=1,
          weight_decay=1e-5,
          momentum=0.9):
    """Fit `model` with MSE loss and the Ranger optimizer, logging to tensorboard.

    Args:
        model: network to train; its `features`, `activation_fn_name`, `bn`
            and `p` attributes are only read to build the tensorboard run
            name.
        dataset: training dataset wrapped in a shuffling DataLoader.
        val_X: validation inputs, fed to the model in a single forward pass.
        val_y: validation targets.
        batch_size: mini-batch size for the training loader.
        epochs: number of passes over `dataset`.
        epoch_show: validate (and print) every `epoch_show` epochs; the last
            epoch is always validated.
        weight_decay: weight decay handed to the optimizer.
        momentum: only recorded in the run name; the active optimizer does
            not take it.

    Returns:
        The SummaryWriter the run was logged to.

    Side effects:
        Saves the best-so-far weights to `models/running_best_model.pt` and
        the final weights to `models/final_model.pt`.
    """
    import os

    train_loader = DataLoader(dataset=dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=0)
    optimizer = Ranger(model.parameters(), weight_decay=weight_decay)
    # optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay)
    #     optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=momentum, weight_decay=weight_decay)
    #     min_lr = 1e-4
    #     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
    #                                                      mode='min',
    #                                                      factor=0.1,
    #                                                      patience=10,
    #                                                      threshold=0,
    #                                                      min_lr=min_lr,
    #                                                      verbose=True)
    loss_fn = nn.MSELoss(reduction='mean')

    writer = SummaryWriter(
        'runs/batchSize/linear_hidden={}_neurons={}_{}_batch={:04d}_bn={}_p={:.1f}_mom={}_l2={}'
        .format(
            len(model.features) - 1, model.features[-1],
            model.activation_fn_name, batch_size, model.bn, model.p, momentum,
            weight_decay))

    # torch.save does not create missing directories
    os.makedirs('models', exist_ok=True)

    best_val_mape = 100
    for epoch in range(epochs):
        # if get_lr(optimizer) < 2 * min_lr:
        #     break
        model.train()
        loss_epoch = 0.0
        num_batches = 0
        for x, y in train_loader:
            out = model(x)
            loss_iter = loss_fn(y, out.squeeze())
            optimizer.zero_grad()
            loss_iter.backward()
            optimizer.step()
            # .item() detaches the value; summing the tensors themselves
            # would keep every batch's autograd graph alive for the epoch.
            loss_epoch += loss_iter.item()
            num_batches += 1
        # guard against an empty loader instead of relying on the loop index
        loss_train = loss_epoch / max(num_batches, 1)
        writer.add_scalar("loss/train", loss_train, epoch)

        if epoch % epoch_show == 0 or epoch == epochs - 1:
            with torch.no_grad():
                model.eval()
                pred = model(val_X)
                loss_val = loss_fn(val_y, pred.squeeze())
                mape = mean_absolute_percentage_error(val_y, pred.squeeze())
                #                 scheduler.step(best_val_mape)
                print(
                    '\nEpoch {:03d}: loss_train={:.6f}, loss_val={:.6f}, val_mape={:.4f}, '
                    'best_val_mape={:.4f}'.format(epoch, loss_train,
                                                  loss_val, mape,
                                                  best_val_mape),
                    end='  ')
                if mape < best_val_mape:
                    model_name = 'running_best_model.pt'
                    print(
                        'Val_mape improved from {:.4f} to {:.4f}, saving model to {}'
                        .format(best_val_mape, mape, model_name),
                        end=' ')
                    best_val_mape = mape
                    torch.save(model.state_dict(), 'models/' + model_name)
            # Validation metrics are logged only for epochs where they were
            # actually computed, so no stale values end up in the curves.
            writer.add_scalar("loss/val", loss_val, epoch)
            writer.add_scalar("mape/val", mape, epoch)
    torch.save(model.state_dict(), 'models/final_model.pt')
    writer.add_scalar("mape/best_val", best_val_mape)
    return writer
Esempio n. 23
0
class Brain(object):
    """
    High-level model logic and tuning nuggets encapsulated.

    Based on efficientNet: https://arxiv.org/abs/1905.11946
    Fine-tunes an EfficientNet for classification and object detection
    (one bounding box and one class label per image).
    In this implementation no weights are frozen; ideally, batchnorm layers
    can be frozen for a marginal training speed increase.
    """

    def __init__(self, gradient_accum_steps=5, lr=0.0005, epochs=100, n_class=2, lmb=30):
        """
        Build the network, losses, optimizer, scheduler and AMP gradient scaler.

        Args:
            gradient_accum_steps: number of batches whose gradients are
                accumulated before each optimizer step.
            lr: base learning rate handed to the optimizer.
            epochs: number of epochs ``run_training_loop`` iterates over.
            n_class: number of classes passed to ``EFN_Classifier``.
            lmb: multiplier applied to the classification loss before it is
                added to the bounding-box loss.
        """
        self.device = self.set_cuda_device()
        self.net = EFN_Classifier("tf_efficientnet_b1_ns", n_class).to(self.device)
        self.loss_function = nn.MSELoss()  # bounding-box regression loss
        self.clf_loss_function = nn.CrossEntropyLoss()
        # NOTE(review): weight_decay=0.999 is unusually large for a decay
        # coefficient -- confirm it is intentional.
        self.optimizer = Ranger(self.net.parameters(), lr=lr, weight_decay=0.999, betas=(0.9, 0.999))
        self.scheduler = CosineAnnealingLR(self.optimizer, epochs * 0.5, lr * 0.0001)
        # NOTE(review): overrides the scheduler's internal epoch counter;
        # the intent is not visible from this class -- confirm.
        self.scheduler.last_epoch = epochs
        self.scaler = GradScaler()
        self.epochs = epochs
        self.gradient_accum_steps = gradient_accum_steps
        self.lmb = lmb

    @staticmethod
    def set_cuda_device():
        """Return ``cuda:0`` when CUDA is available, otherwise the CPU device."""
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
            logging.info(f"Running on {torch.cuda.get_device_name()}")
        else:
            device = torch.device("cpu")
            logging.info("Running on a CPU")
        return device

    def run_training_loop(self, train_dataloader, valid_dataloader, model_filename):
        """
        Train for ``self.epochs`` epochs, validating after every epoch.

        The network's ``state_dict`` is written to ``model_filename`` every
        time the mean validation loss improves on the best value seen so far.

        Args:
            train_dataloader: batches used for training.
            valid_dataloader: batches used for validation.
            model_filename: destination path for the best checkpoint.
        """
        best_loss = float("inf")

        for epoch in range(self.epochs):
            if epoch != 0 and epoch > 0.5 * self.epochs:  # cosine anneal the last 50% of epochs
                self.scheduler.step()
            logging.info(f"Epoch {epoch+1}")

            logging.info("Training")
            train_losses, train_accuracies, train_miou = self.forward_pass(train_dataloader, train=True)

            logging.info("Validating")
            val_losses, val_accuracies, val_miou = self.forward_pass(valid_dataloader)

            logging.info(
                f"Training accuracy: {sum(train_accuracies)/len(train_accuracies):.2f}"
                f" | Training loss: {sum(train_losses)/len(train_losses):.2f}"
                f" | Training mIoU: {sum(train_miou)/len(train_miou):.2f}"
            )
            logging.info(
                f"Validation accuracy: {sum(val_accuracies)/len(val_accuracies):.2f}"
                f" | Validation loss: {sum(val_losses)/len(val_losses):.2f}"
                f" | Validation mIoU: {sum(val_miou)/len(val_miou):.2f}"
            )

            epoch_val_loss = sum(val_losses) / len(val_losses)

            if best_loss > epoch_val_loss:
                # Remember the old best before overwriting it so the log line
                # reports the actual improvement.
                previous_best = best_loss
                best_loss = epoch_val_loss
                torch.save(self.net.state_dict(), model_filename)
                logging.info(f"Saving with loss of {epoch_val_loss}, improved over previous {previous_best}")

    def bbox_iou(self, true_boxes, pred_boxes):
        """
        Mean intersection-over-union of paired boxes.

        Each box is indexed as ``[x_left, y_top, x_right, y_bottom]``.
        Pairs that do not overlap (or whose union has no area) contribute an
        IoU of 0 to the mean.

        Args:
            true_boxes: iterable of ground-truth boxes (tensor rows).
            pred_boxes: iterable of predicted boxes, paired by position.

        Returns:
            0-dim tensor holding the mean IoU over all pairs; ``0.0`` when
            there are no pairs.
        """
        iou_list = []
        for true_box, pred_box in zip(true_boxes, pred_boxes):

            x_left = max(true_box[0], pred_box[0]).item()
            y_top = max(true_box[1], pred_box[1]).item()

            x_right = min(true_box[2], pred_box[2]).item()
            y_bottom = min(true_box[3], pred_box[3]).item()

            if x_right < x_left or y_bottom < y_top:
                # No intersection for this pair only; keep scoring the rest.
                iou_list.append(0.0)
                continue

            overlap = (x_right - x_left) * (y_bottom - y_top)

            true_box_area = ((true_box[2] - true_box[0]) * (true_box[3] - true_box[1])).item()
            pred_box_area = ((pred_box[2] - pred_box[0]) * (pred_box[3] - pred_box[1])).item()
            union = true_box_area + pred_box_area - overlap
            # Degenerate (zero-area) boxes would otherwise divide by zero.
            iou_list.append(overlap / union if union > 0 else 0.0)

        if not iou_list:
            return torch.tensor(0.0)

        return torch.tensor(iou_list).mean()

    def draw_boxes(self, images, bboxes, labels):
        """
        Show each image of a batch with its box and class name in an OpenCV window.

        Args:
            images: batch of image tensors.
            bboxes: batch of boxes, indexed as ``[x_left, y_top, x_right, y_bottom]``.
            labels: batch of per-class scores; the arg-max picks the class name.
        """
        label_dict = {0: "Cat", 1: "Dog"}

        for batch in zip(images, bboxes, labels):
            cv2.destroyAllWindows()
            image, bbox, label = batch[0].cpu().numpy(), batch[1].cpu().numpy(), torch.argmax(batch[2]).cpu().item()
            # Move axis 0 to the end (presumably CHW -> HWC for OpenCV -- confirm).
            image = np.rollaxis(image, 0, 3)
            # Min-max rescale to the 0..255 range expected for uint8 display.
            image = ((image - image.min()) * (1 / (image.max() - image.min()) * 255)).astype("uint8")
            image = cv2.UMat(image)

            cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), thickness=2)

            # NOTE(review): the text origin uses (bbox[1], bbox[3]), i.e. a y
            # coordinate in the x slot -- confirm this is intended.
            cv2.putText(
                image,
                f"{label_dict[label]}",
                (bbox[1], bbox[3]),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (0, 255, 0),
                1,
                cv2.LINE_AA,
            )
            cv2.imshow("test", image)
            cv2.waitKey(1)
            sleep(1)
            cv2.destroyAllWindows()

    def forward_pass(self, dataloader, draw=False, train=False):
        """
        Run one pass over ``dataloader``, optionally training the network.

        Args:
            dataloader: yields batches where ``batch[0]`` is the input and
                ``batch[1]`` the labels; label columns 0-3 are the box and
                column 4 is the class index.
            draw: when true (and not training), display predictions via
                ``draw_boxes``.
            train: when true, back-propagate and step the optimizer every
                ``self.gradient_accum_steps`` batches.

        Returns:
            Tuple ``(losses, accuracies, miou)`` of per-batch lists: detached
            loss tensors, accuracy fractions, and mean-IoU tensors.
        """
        def get_loss(inputs, bbox_labels, clf_labels):
            label_outputs, bbox_outputs = self.net(inputs)
            bbox_loss = self.loss_function(bbox_outputs, bbox_labels)
            clf_loss = self.clf_loss_function(label_outputs, clf_labels)
            loss = torch.mean(bbox_loss + clf_loss * self.lmb)
            return loss, label_outputs, bbox_outputs

        if train:
            self.net.train()
        else:
            self.net.eval()

        losses = []
        accuracies = []
        miou = []

        for step, batch in enumerate(dataloader):
            inputs = batch[0].to(self.device).float()
            labels = batch[1].to(self.device).float()

            # splitting labels for separate loss calculation
            bbox_labels = labels[:, :4]
            clf_labels = labels[:, 4:].long()
            clf_labels = clf_labels[:, 0]

            # Only the forward pass runs under autocast; the backward pass is
            # issued outside the context, as the AMP documentation recommends.
            with autocast():
                if train:
                    loss, label_outputs, bbox_outputs = get_loss(inputs, bbox_labels, clf_labels)
                else:
                    with torch.no_grad():
                        loss, label_outputs, bbox_outputs = get_loss(inputs, bbox_labels, clf_labels)

            if train:
                self.scaler.scale(loss).backward()
            elif draw:
                self.draw_boxes(inputs, bbox_outputs, label_outputs)

            # Fraction of samples whose highest-scoring class matches the label.
            correct = (label_outputs.argmax(dim=1) == clf_labels).sum().item()
            acc = correct / len(clf_labels)
            iou = self.bbox_iou(bbox_labels, bbox_outputs)

            miou.append(iou)
            # Detach so the stored value does not keep the autograd graph alive.
            losses.append(loss.detach())
            accuracies.append(acc)

            if train and (step + 1) % self.gradient_accum_steps == 0:
                # gradient accumulation to train with bigger effective batch size
                # with less memory use
                # fp16 is used to speed up training and reduce memory consumption
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()
                logging.info(
                    f"Step {step} of {len(dataloader)},\t"
                    f"Accuracy: {sum(accuracies)/len(accuracies):.2f},\t"
                    f"mIoU: {sum(miou)/len(miou):.2f},\t"
                    f"Loss: {sum(losses)/len(losses):.2f}"
                )

        return losses, accuracies, miou
Esempio n. 24
0
def train_generator_PG(dis, dis_option_encoder, state_encoder, reason_decoder, gen_optimizer, epochs, beta=15, max_length=25):
	"""
	The generator is trained using policy gradients, using the reward from the discriminator.

	For every batch of ``train_loader_adv`` the reason decoder greedily decodes
	a reason for each statement, the decoded text replaces the reason of one
	randomly chosen wrong option, and the discriminator scores the options.
	The loss is the NLL of the reference reason plus a ``beta``-weighted
	penalty on the discriminator outputs for the wrong labels.

	Args:
		dis: discriminator, called as ``dis(batch_size, option_hiddens)``.
		dis_option_encoder: encoder applied to each of the three option sets.
		state_encoder: encoder for the input statements.
		reason_decoder: decoder producing one output distribution per step.
		gen_optimizer: accepted for interface compatibility only (see note below).
		epochs: number of passes over ``train_loader_adv``.
		beta: weight of the discriminator-based label loss.
		max_length: maximum number of decoding steps.
	"""
	# NOTE(review): the optimizer passed by the caller is discarded and replaced
	# by a fresh Ranger instance -- confirm this is intended.
	gen_optimizer = Ranger(itertools.chain(state_encoder.parameters(), reason_decoder.parameters()), lr=1e-3,
						   weight_decay=1e-4)
	criterion = nn.NLLLoss()

	def as_float(value):
		# Losses start as the int 0 and become tensors once a term is added;
		# detach before converting so no autograd history is retained.
		return value.detach().item() if torch.is_tensor(value) else float(value)

	for epoch in tqdm(range(epochs)):
		sys.stdout.flush()
		total_mle_loss = 0.0
		total_label_loss = 0.0
		count = 0
		for train_data in train_loader_adv:
			gen_optimizer.zero_grad()
			mle_loss = 0
			label_loss = 0
			num, statement, reason, label, options = train_data
			reason = torch.tensor(reason).type(torch.cuda.LongTensor)
			batch_size, input_size = statement.size()
			# For each sample keep the two labels that are NOT the correct one
			# and pick one of them at random to receive the generated reason.
			falselabel = [[0, 1, 2] for l in label]
			for i in range(len(label)): falselabel[i].remove(label[i])
			choice_label = [random.choice(ls) for ls in falselabel]

			input_lengths = [len(x) for x in statement]
			encoder_outputs, hidden = state_encoder(statement, input_lengths)

			decoder_input = torch.tensor([[word_encoder.encode("<sos>")] for i in range(batch_size)]).type(
				torch.cuda.LongTensor)

			encoder_outputs = encoder_outputs.permute(1, 0, 2)  # -> (T*B*H)
			target_length = reason.size()[1]
			decoder_outputs = [[] for i in range(batch_size)]
			eos_idx = word_encoder.encode("<eos>")
			reason_permute = reason.permute(1, 0)
			for di in range(max_length):
				output, hidden = reason_decoder(decoder_input, hidden, encoder_outputs)
				if di < target_length:
					mle_loss += criterion(output, reason_permute[di])
				_, topi = output.topk(1)
				top_tokens = topi.squeeze()
				all_end = True
				for i in range(batch_size):
					token = top_tokens[i].item()
					if token != eos_idx:
						decoder_outputs[i].append(word_encoder.decode(np.array([token])))
						all_end = False
				if all_end:
					break
				# Greedy decoding: feed the arg-max token back in, detached from the graph.
				decoder_input = top_tokens.detach()
			origin_decoder_outputs = decoder_outputs
			decoder_outputs = [" ".join(output) for output in decoder_outputs]

			correct_option = ["<sep>".join(options[label[i]][i]) for i in range(batch_size)]
			correct_option = [word_encoder.encode(option) for option in correct_option]
			# Replace the reason of the chosen wrong option with the generated text.
			for i in range(batch_size):
				options[choice_label[i]][i][1] = decoder_outputs[i]

			option_input_lengths = [[len(x) for x in option] for option in options]
			option_lens = [max(len(option[0]) + len(option[1]) + 1 for option in options[i]) for i in range(3)]
			options = [[word_encoder.encode("<sep>".join([x[0], x[1]])) for x in option] for option in options]
			options = [sequence.pad_sequences(option, maxlen=option_len, padding='post') for option, option_len in
					   zip(options, option_lens)]
			options = [torch.tensor(option).type(torch.cuda.LongTensor) for option in options]

			option_hiddens = []
			for i in range(3):
				encoder_outputs, option_hidden = dis_option_encoder(options[i], option_input_lengths[i])
				option_hiddens.append(option_hidden)

			out = dis(batch_size, option_hiddens)

			for i in range(batch_size):
				# The penalty is identical for every counted token of a sample, so
				# build the mask once and scale by the token count instead of
				# rebuilding it and subtracting it token by token.
				n_tokens = min(len(origin_decoder_outputs[i]), correct_option[i].size()[0])
				if n_tokens == 0:
					continue
				false_one_hot = torch.zeros(1, 3)
				false_one_hot[0][falselabel[i]] = 1
				false_one_hot = false_one_hot.type(torch.cuda.FloatTensor)
				label_loss = label_loss - n_tokens * beta * out[i].mul(false_one_hot).sum()

			# Accumulate plain floats: adding the loss tensors themselves would
			# keep every batch's autograd history referenced for the whole epoch.
			total_label_loss += as_float(label_loss)
			total_mle_loss += as_float(mle_loss)
			loss = label_loss + mle_loss
			loss.backward()
			gen_optimizer.step()
			count += 1
		total_mle_loss = total_mle_loss / (count * BATCH_SIZE)
		total_label_loss = total_label_loss / (count * BATCH_SIZE)

		print('\n average_train_NLL = ', total_mle_loss,' the label loss = ', total_label_loss)
Esempio n. 25
0
def train_fold():
    """
    Fine-tune the pretrained NucleicTransformer on one cross-validation fold.

    Reads its configuration from ``get_args()``, loads and filters
    ``train.json``, builds the train/validation loaders for ``opts.fold``,
    restores the last pretraining checkpoint listed in ``pretrain.csv`` and
    trains for ``opts.epochs`` epochs. Cosine annealing and validation only
    run once ``epoch > cos_epoch`` (``cos_epoch = int(opts.epochs*0.75)-1``).
    Checkpoints are saved every ``opts.save_freq`` epochs, and
    ``get_best_weights_from_fold`` is called at the end.
    """
    # get arguments
    opts = get_args()

    # gpu selection
    os.environ["CUDA_VISIBLE_DEVICES"] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # instantiate datasets; rows at or below the signal-to-noise threshold are dropped
    json_path = os.path.join(opts.path, 'train.json')

    train_df = pd.read_json(json_path, lines=True)
    train_df = train_df[train_df.signal_to_noise > opts.noise_filter]
    ids = np.asarray(train_df.id.to_list())

    # turn the errors returned by get_errors into per-sample weights
    error_weights = get_errors(train_df)
    error_weights = opts.error_alpha + np.exp(-error_weights * opts.error_beta)
    train_indices, val_indices = get_train_val_indices(train_df, opts.fold, SEED=2020, nfolds=opts.nfolds)

    _, labels = get_data(train_df)
    sequences = np.asarray(train_df.sequence)
    train_seqs = sequences[train_indices]
    val_seqs = sequences[val_indices]
    train_labels = labels[train_indices]
    val_labels = labels[val_indices]
    train_ids = ids[train_indices]
    val_ids = ids[val_indices]
    train_ew = error_weights[train_indices]
    val_ew = error_weights[val_indices]

    dataset = RNADataset(train_seqs, train_labels, train_ids, train_ew, opts.path)
    val_dataset = RNADataset(val_seqs, val_labels, val_ids, val_ew, opts.path, training=False)
    dataloader = DataLoader(dataset, batch_size=opts.batch_size,
                            shuffle=True, num_workers=opts.workers)
    val_dataloader = DataLoader(val_dataset, batch_size=opts.batch_size*2,
                                shuffle=False, num_workers=opts.workers)

    # checkpointing
    checkpoints_folder = 'checkpoints_fold{}'.format(opts.fold)
    csv_file = 'log_fold{}.csv'.format(opts.fold)
    columns = ['epoch', 'train_loss',
               'val_loss']
    logger = CSVLogger(columns, csv_file)

    # build model and logger
    model = NucleicTransformer(opts.ntoken, opts.nclass, opts.ninp, opts.nhead, opts.nhid,
                               opts.nlayers, opts.kmer_aggregation, kmers=opts.kmers, stride=opts.stride,
                               dropout=opts.dropout).to(device)
    optimizer = Ranger(model.parameters(), weight_decay=opts.weight_decay)
    criterion = weighted_MCRMSE

    # Wrap before loading: the state dict is loaded into the DataParallel
    # wrapper, so the statement order here matters.
    model = nn.DataParallel(model)
    pretrained_df = pd.read_csv('pretrain.csv')
    model.load_state_dict(torch.load('pretrain_weights/epoch{}.ckpt'.format(int(pretrained_df.iloc[-1].epoch))))

    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print('Total number of paramters: {}'.format(pytorch_total_params))

    # training loop
    cos_epoch = int(opts.epochs*0.75) - 1
    # The scheduler is stepped once per batch, so its period is counted in batches.
    lr_schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, (opts.epochs-cos_epoch)*len(dataloader))
    for epoch in range(opts.epochs):
        model.train(True)
        t = time.time()
        total_loss = 0.0
        optimizer.zero_grad()
        # `step` counts the batches processed so far (1-based); it stays 0
        # when the loader yields nothing.
        step = 0
        for step, data in enumerate(dataloader, start=1):
            lr = get_lr(optimizer)
            src = data['data'].to(device)
            labels = data['labels'].to(device)
            bpps = data['bpp'].to(device)
            output = model(src, bpps)
            ew = data['ew'].to(device)
            loss = criterion(output[:, :68], labels, ew).mean()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
            optimizer.zero_grad()
            # Accumulate a plain float so no autograd history is retained and
            # the CSV log receives a number rather than a tensor.
            total_loss += loss.item()
            print("Epoch [{}/{}], Step [{}/{}] Loss: {:.3f} Lr:{:.6f} Time: {:.1f}"
                  .format(epoch+1, opts.epochs, step, len(dataloader), total_loss/step, lr, time.time()-t),
                  end='\r', flush=True)
            if epoch > cos_epoch:
                lr_schedule.step()
        print('')
        # Average over the batches actually seen.
        train_loss = total_loss / max(step, 1)
        torch.cuda.empty_cache()
        if (epoch+1) % opts.val_freq == 0 and epoch > cos_epoch:
            val_loss = validate(model, device, val_dataloader, batch_size=opts.batch_size)
            to_log = [epoch+1, train_loss, val_loss, ]
            logger.log(to_log)

        if (epoch+1) % opts.save_freq == 0:
            save_weights(model, optimizer, epoch, checkpoints_folder)

    get_best_weights_from_fold(opts.fold)
Esempio n. 26
0
def train(fold_idx=None):
    """
    Train a DeepLabV3+ segmentation model and keep the best checkpoint.

    After every epoch the model is scored with ``eval_net_unet_miou``; the
    weights are saved whenever the score improves. Training stops early once
    more than ``config.num_patience_epoch`` epochs pass without improvement.

    Args:
        fold_idx: zero-based fold index used to suffix the checkpoint name,
            or ``None`` to use the plain ``config.save_model_name``.

    Returns:
        The best validation score reached.
    """
    # model = UNet(n_classes=1, n_channels=3)
    model = DeepLabV3_plus(num_classes=1, backbone='resnet', sync_bn=True)
    # Batches are moved to `device` below, so the model must live there too.
    model = model.to(device)
    train_dataloader, valid_dataloader = get_trainval_dataloader()
    criterion = nn.BCEWithLogitsLoss()
    optimizer = Ranger(model.parameters(), lr=1e-3, weight_decay=0.0005)
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                                  T_max=10)

    best_val_score = 0
    last_improved_epoch = 0
    if fold_idx is None:
        print('start')
        model_save_path = os.path.join(config.dir_weight,
                                       '{}.bin'.format(config.save_model_name))
    else:
        print('start fold: {}'.format(fold_idx + 1))
        model_save_path = os.path.join(
            config.dir_weight, '{}_fold{}.bin'.format(config.save_model_name,
                                                      fold_idx))
    for cur_epoch in range(config.num_epochs):
        start_time = int(time.time())
        model.train()
        print('epoch: ', cur_epoch + 1)
        cur_step = 0
        for batch in train_dataloader:
            batch_x = batch['image']
            batch_y = batch['mask']
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            optimizer.zero_grad()
            mask_pred = model(batch_x)
            train_loss = criterion(mask_pred, batch_y)
            train_loss.backward()
            optimizer.step()

            cur_step += 1
            if cur_step % config.step_train_print == 0:
                train_acc = accuracy(mask_pred, batch_y)
                msg = 'the current step: {0}/{1}, train loss: {2:>5.2}, train acc: {3:>6.2%}'
                print(
                    msg.format(cur_step, len(train_dataloader),
                               train_loss.item(), train_acc[0].item()))

        val_miou = eval_net_unet_miou(model, valid_dataloader, device)
        val_score = val_miou
        if val_score > best_val_score:
            best_val_score = val_score
            torch.save(model.state_dict(), model_save_path)
            improved_str = '*'
            last_improved_epoch = cur_epoch
        else:
            improved_str = ''
        # Indices 0-4 match the five arguments passed to format(); the
        # previous {3}/{4}/{5} numbering raised IndexError on the first epoch.
        msg = 'the current epoch: {0}/{1}, val score: {2:>6.2%}, cost: {3}s {4}'
        end_time = int(time.time())
        print(
            msg.format(cur_epoch + 1, config.num_epochs, val_score,
                       end_time - start_time, improved_str))
        if cur_epoch - last_improved_epoch > config.num_patience_epoch:
            print("No optimization for a long time, auto-stopping...")
            break
        scheduler_cosine.step()
    del model
    gc.collect()
    return best_val_score