# One training step for the sequence-decoding setup, followed by periodic
# TensorBoard logging. Relies on names from the enclosing training loop:
# e (iteration counter), losses (running list of per-step losses), model,
# criterion, optimizer, writer, Y, y_pred, out_seq_len.

# Fill the prediction one output step at a time; each call to model() (no
# input) produces column i of y_pred.
for i in range(0, out_seq_len):
    y_pred[:, i] = model()

# Backpropagate, clip gradients, and update the parameters.
loss = criterion(y_pred, Y)
loss.backward()
clip_grads(model)
optimizer.step()
losses += [loss.item()]

if e % 50 == 0:
    # Log the mean of the last 50 recorded step losses under 'Mean loss'.
    # (The single-step loss.item() was previously logged under this tag,
    # leaving mean_loss unused.)
    mean_loss = np.array(losses[-50:]).mean()
    print("Loss: ", loss.item())
    writer.add_scalar('Mean loss', mean_loss, e)
    # Reset after logging so the list stays bounded between log points.
    losses = []

# e % 1000 == 0 implies e % 50 == 0, so this runs on a subset of the
# logging iterations above.
if e % 1000 == 0:
    for name, param in model.named_parameters():
        # detach() instead of .data: same values, no autograd tracking.
        writer.add_histogram(name, param.detach().cpu().numpy(), e)
    mem_pic, read_pic, write_pic = model.get_memory_info()
    pic1 = vutils.make_grid(y_pred, normalize=True, scale_each=True)
    pic2 = vutils.make_grid(Y, normalize=True, scale_each=True)
    pic3 = vutils.make_grid(mem_pic, normalize=True, scale_each=True)
    pic4 = vutils.make_grid(read_pic, normalize=True, scale_each=True)
    pic5 = vutils.make_grid(write_pic, normalize=True, scale_each=True)
    # NOTE(review): the grids above are built but never written because the
    # image logging and checkpointing below are commented out. Either
    # re-enable these lines or drop the grid construction.
    #writer.add_image('NTM output', pic1, e)
    #writer.add_image('True output', pic2, e)
    #writer.add_image('Memory', pic3, e)
    #writer.add_image('Read weights', pic4, e)
    #writer.add_image('Write weights', pic5, e)
    #torch.save(model.state_dict(), args.savemodel)
# One training step for the program-learning setup, followed by periodic
# TensorBoard logging. Relies on names from the enclosing training loop:
# e (iteration counter), losses (running list of per-step losses), model,
# criterion, optimizer, writer, Y, y_pred, args, colorize.

# Move the target to the prediction's device instead of hard-coding .cuda().
# This is identical whenever the original worked, since the loss needs both
# tensors on one device, and it also works on CPU or a non-default GPU.
loss = criterion(y_pred, Y.to(y_pred.device))
loss.backward()
clip_grads(model)
optimizer.step()
losses += [loss.item()]

if e % 200 == 0:
    # Log the mean of the last 200 recorded step losses, then reset.
    mean_loss = np.array(losses[-200:]).mean()
    print("Loss: ", loss.item())
    writer.add_scalar('Mean loss', mean_loss, e)
    losses = []

# e % 5000 == 0 implies e % 200 == 0, so this runs on a subset of the
# logging iterations above.
if e % 5000 == 0:
    print(y_pred)
    print(Y)
    mem_pic, read_pic, write_pic, ntm_programs = model.get_memory_info()
    # Detach once so the visualization does not build on the autograd graph
    # (the previous version stripped it later via pic1.data).
    y_pred_img = y_pred.detach()
    pic1 = vutils.make_grid(y_pred_img, normalize=True, scale_each=True)
    pic2 = vutils.make_grid(Y, normalize=True, scale_each=True)
    pic3 = vutils.make_grid(mem_pic, normalize=True, scale_each=True)
    pic4 = vutils.make_grid(read_pic, normalize=True, scale_each=True)
    pic5 = vutils.make_grid(write_pic, normalize=True, scale_each=True)
    writer.add_image('NTM output', colorize(pic1), e)
    writer.add_image('True output', colorize(pic2), e)
    writer.add_image('Memory', pic3, e)
    writer.add_image('Read addressing', pic4, e)
    writer.add_image('Write addressing', pic5, e)
    # A distinct index name avoids clobbering any outer loop variable `i`.
    for prog_idx, ntm_program in enumerate(ntm_programs):
        program_img = ntm_program.view(args.function_size, args.function_size)
        pic6 = vutils.make_grid(program_img, normalize=True, scale_each=True)
        writer.add_image(f'NTM software learned {prog_idx}', pic6, e)